bro's coding

GAN.train.camel

givemebro 2020. 7. 14. 14:31

discriminator : its goal is to classify real images as real and generated (fake) images as fake
generator : it takes random noise as input and generates fake images; using the discriminator's verdicts as feedback, its goal is to make them look like real images

batch_size : 64
epochs : 6000 (in this implementation each "epoch" trains on a single batch; the training log below has one line per step)
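These two objectives are trained in alternation. A sketch of what one training step inside gan.train() amounts to (the three callables stand in for the Keras models' train_on_batch calls; the function and argument names here are illustrative, not the book's exact API):

```python
import numpy as np

def train_gan_batch(x_train, generate, train_discriminator, train_generator,
                    batch_size=64, z_dim=100, rng=None):
    """One alternating GAN update: discriminator first, then generator."""
    rng = rng or np.random.RandomState(0)
    valid = np.ones((batch_size, 1))   # target label: "real"
    fake = np.zeros((batch_size, 1))   # target label: "fake"

    # discriminator step: real images labelled 1, generated images labelled 0
    idx = rng.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]
    noise = rng.normal(0, 1, (batch_size, z_dim))
    gen_imgs = generate(noise)
    d_loss_real = train_discriminator(real_imgs, valid)
    d_loss_fake = train_discriminator(gen_imgs, fake)

    # generator step: train through the frozen discriminator with "real" labels,
    # i.e. push the discriminator to call the fakes real
    noise = rng.normal(0, 1, (batch_size, z_dim))
    g_loss = train_generator(noise, valid)
    return 0.5 * (d_loss_real + d_loss_fake), g_loss
```

This matches the log format below: the reported D loss is the average of the real and fake halves, and the G loss comes from the combined model trained with "real" labels.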

The GAN class imported below is defined in an earlier post, GAN.source: https://broscoding.tistory.com/323

import os
import matplotlib.pyplot as plt

from models.GAN import GAN
from utils.loaders import load_safari

# 80,000 camel drawings
# 28 x 28 pixels each

 

# run params
SECTION = 'gan'
RUN_ID = '0001'
DATA_NAME = 'camel'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

if not os.path.exists(RUN_FOLDER):
    os.makedirs(RUN_FOLDER)  # makedirs also creates the parent 'run/gan' on a fresh checkout
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))

mode = 'build'  # set to 'load' to resume from saved weights instead

 

(x_train, y_train) = load_safari(DATA_NAME)
x_train.shape
(80000, 28, 28, 1)
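load_safari comes from the book's utils module and is not shown here. A common convention for GAN inputs (and an assumption about what it does) is to scale the raw 0-255 Quick, Draw! bitmaps to [-1, 1] so they match a tanh output on the generator's last layer:

```python
import numpy as np

def scale_images(raw):
    # assumption: byte pixels 0..255 mapped linearly to [-1, 1]
    return (raw.astype('float32') - 127.5) / 127.5

x = scale_images(np.array([0, 255], dtype=np.uint8))
print(x)  # [-1.  1.]
```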

 

plt.imshow(x_train[200,:,:,0], cmap = 'gray')  # preview one of the camel drawings

 

# model

gan = GAN(input_dim = (28,28,1)
        , discriminator_conv_filters = [64,64,128,128]
        , discriminator_conv_kernel_size = [5,5,5,5]
        , discriminator_conv_strides = [2,2,2,1]
        , discriminator_batch_norm_momentum = None
        , discriminator_activation = 'relu'
        , discriminator_dropout_rate = 0.4
        , discriminator_learning_rate = 0.0008
        , generator_initial_dense_layer_size = (7, 7, 64)
        , generator_upsample = [2,2, 1, 1]
        , generator_conv_filters = [128,64, 64,1]
        , generator_conv_kernel_size = [5,5,5,5]
        , generator_conv_strides = [1,1, 1, 1]
        , generator_batch_norm_momentum = 0.9
        , generator_activation = 'relu'
        , generator_dropout_rate = None
        , generator_learning_rate = 0.0004
        , optimiser = 'rmsprop'
        , z_dim = 100
        )

if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))

 

gan.discriminator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
discriminator_input (InputLa (None, 28, 28, 1)         0         
_________________________________________________________________
discriminator_conv_0 (Conv2D (None, 14, 14, 64)        1664      
_________________________________________________________________
activation_1 (Activation)    (None, 14, 14, 64)        0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 14, 14, 64)        0         
_________________________________________________________________
discriminator_conv_1 (Conv2D (None, 7, 7, 64)          102464    
_________________________________________________________________
activation_2 (Activation)    (None, 7, 7, 64)          0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 7, 7, 64)          0         
_________________________________________________________________
discriminator_conv_2 (Conv2D (None, 4, 4, 128)         204928    
_________________________________________________________________
activation_3 (Activation)    (None, 4, 4, 128)         0         
_________________________________________________________________
dropout_3 (Dropout)          (None, 4, 4, 128)         0         
_________________________________________________________________
discriminator_conv_3 (Conv2D (None, 4, 4, 128)         409728    
_________________________________________________________________
activation_4 (Activation)    (None, 4, 4, 128)         0         
_________________________________________________________________
dropout_4 (Dropout)          (None, 4, 4, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 720,833
Trainable params: 720,833
Non-trainable params: 0
_________________________________________________________________
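The shapes and the first parameter count in this summary can be checked by hand: Keras Conv2D with padding='same' produces an output size of ceil(input / stride), and a 5x5 kernel over 1 input channel into 64 filters costs 5*5*1*64 weights plus 64 biases.

```python
import math

def conv_out(size, stride):
    # Conv2D with padding='same': output size = ceil(input / stride)
    return math.ceil(size / stride)

sizes, size = [], 28
for stride in [2, 2, 2, 1]:        # discriminator_conv_strides
    size = conv_out(size, stride)
    sizes.append(size)
print(sizes)                        # [14, 7, 4, 4], as in the summary

# first conv layer: 5x5 kernel, 1 input channel, 64 filters, plus 64 biases
print(5 * 5 * 1 * 64 + 64)          # 1664
```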

 

gan.generator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              316736    
_________________________________________________________________
batch_normalization_1 (Batch (None, 3136)              12544     
_________________________________________________________________
activation_5 (Activation)    (None, 3136)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 14, 14, 64)        0         
_________________________________________________________________
generator_conv_0 (Conv2D)    (None, 14, 14, 128)       204928    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
activation_6 (Activation)    (None, 14, 14, 128)       0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 28, 28, 128)       0         
_________________________________________________________________
generator_conv_1 (Conv2D)    (None, 28, 28, 64)        204864    
_________________________________________________________________
batch_normalization_3 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
activation_7 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
generator_conv_2 (Conv2DTran (None, 28, 28, 64)        102464    
_________________________________________________________________
batch_normalization_4 (Batch (None, 28, 28, 64)        256       
_________________________________________________________________
activation_8 (Activation)    (None, 28, 28, 64)        0         
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 28, 28, 1)         1601      
_________________________________________________________________
activation_9 (Activation)    (None, 28, 28, 1)         0         
=================================================================
Total params: 844,161
Trainable params: 837,377
Non-trainable params: 6,784
_________________________________________________________________
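The generator's non-trainable parameters all come from batch normalization: each BatchNormalization layer holds 4 values per channel (gamma, beta, moving mean, moving variance), and the two moving statistics are not trained. That accounts for every BN row and the 6,784 non-trainable total:

```python
# Dense layer into the 7x7x64 tensor: 100 inputs -> 3136 units, plus biases
dense_params = 100 * 3136 + 3136
print(dense_params)                 # 316736

# BatchNormalization: 4 params per channel, 2 of them non-trainable
bn_channels = [3136, 128, 64, 64]
bn_total = sum(4 * c for c in bn_channels)
non_trainable = sum(2 * c for c in bn_channels)
print(bn_total, non_trainable)      # 13568 6784
```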

 

# train

BATCH_SIZE = 64
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
gan.train(     
    x_train
    , batch_size = BATCH_SIZE
    , epochs = EPOCHS
    , run_folder = RUN_FOLDER
    , print_every_n_batches = PRINT_EVERY_N_BATCHES
)
0 [D loss: (0.717)(R 0.698, F 0.735)] [D acc: (0.172)(0.344, 0.000)] [G loss: 0.685] [G acc: 1.000]
1 [D loss: (0.839)(R 0.640, F 1.039)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.674] [G acc: 1.000]
2 [D loss: (1.021)(R 0.634, F 1.408)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.686] [G acc: 1.000]
3 [D loss: (0.696)(R 0.686, F 0.707)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.686] [G acc: 1.000]
4 [D loss: (0.695)(R 0.687, F 0.703)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.687] [G acc: 1.000]
5 [D loss: (0.694)(R 0.688, F 0.701)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.688] [G acc: 1.000]
6 [D loss: (0.693)(R 0.688, F 0.699)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.687] [G acc: 1.000]
7 [D loss: (0.694)(R 0.687, F 0.700)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.688] [G acc: 1.000]
8 [D loss: (0.694)(R 0.687, F 0.701)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.688] [G acc: 1.000]
9 [D loss: (0.694)(R 0.688, F 0.701)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.689] [G acc: 1.000]
10 [D loss: (0.694)(R 0.689, F 0.699)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.689] [G acc: 1.000]
11 [D loss: (0.694)(R 0.689, F 0.699)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.690] [G acc: 0.984]
12 [D loss: (0.694)(R 0.689, F 0.699)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.690] [G acc: 0.984]
13 [D loss: (0.694)(R 0.689, F 0.699)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.690] [G acc: 0.984]
14 [D loss: (0.694)(R 0.690, F 0.698)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.691] [G acc: 0.984]
15 [D loss: (0.694)(R 0.691, F 0.698)] [D acc: (0.477)(0.953, 0.000)] [G loss: 0.691] [G acc: 0.969]
16 [D loss: (0.694)(R 0.691, F 0.697)] [D acc: (0.492)(0.984, 0.000)] [G loss: 0.692] [G acc: 0.906]
17 [D loss: (0.694)(R 0.691, F 0.697)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.692] [G acc: 0.953]
18 [D loss: (0.694)(R 0.691, F 0.696)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.692] [G acc: 0.953]
19 [D loss: (0.694)(R 0.691, F 0.696)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.692] [G acc: 0.859]
20 [D loss: (0.694)(R 0.691, F 0.696)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.692] [G acc: 0.922]
21 [D loss: (0.694)(R 0.692, F 0.696)] [D acc: (0.500)(1.000, 0.000)] [G loss: 0.692] [G acc: 0.891]
22 [D loss: (0.694)(R 0.692, F 0.695)] [D acc: (0.477)(0.953, 0.000)] [G loss: 0.693] [G acc: 0.844]
23 [D loss: (0.694)(R 0.692, F 0.695)] [D acc: (0.477)(0.953, 0.000)] [G loss: 0.693] [G acc: 0.781]
24 [D loss: (0.694)(R 0.692, F 0.695)] [D acc: (0.438)(0.875, 0.000)] [G loss: 0.693] [G acc: 0.844]
25 [D loss: (0.693)(R 0.692, F 0.695)] [D acc: (0.453)(0.906, 0.000)] [G loss: 0.693] [G acc: 0.828]
26 [D loss: (0.693)(R 0.692, F 0.695)] [D acc: (0.492)(0.953, 0.031)] [G loss: 0.693] [G acc: 0.609]
27 [D loss: (0.694)(R 0.692, F 0.695)] [D acc: (0.430)(0.844, 0.016)] [G loss: 0.693] [G acc: 0.484]
28 [D loss: (0.693)(R 0.693, F 0.694)] [D acc: (0.477)(0.859, 0.094)] [G loss: 0.693] [G acc: 0.328]
29 [D loss: (0.693)(R 0.693, F 0.694)] [D acc: (0.422)(0.781, 0.062)] [G loss: 0.693] [G acc: 0.359]
30 [D loss: (0.693)(R 0.693, F 0.694)] [D acc: (0.398)(0.797, 0.000)] [G loss: 0.693] [G acc: 0.312]
.
.
.
5970 [D loss: (0.631)(R 0.613, F 0.648)] [D acc: (0.688)(0.641, 0.734)] [G loss: 1.676] [G acc: 0.078]
5971 [D loss: (0.500)(R 0.611, F 0.389)] [D acc: (0.711)(0.531, 0.891)] [G loss: 1.683] [G acc: 0.031]
5972 [D loss: (0.428)(R 0.425, F 0.431)] [D acc: (0.797)(0.781, 0.812)] [G loss: 1.683] [G acc: 0.062]
5973 [D loss: (0.512)(R 0.573, F 0.451)] [D acc: (0.750)(0.672, 0.828)] [G loss: 1.656] [G acc: 0.078]
5974 [D loss: (0.580)(R 0.719, F 0.440)] [D acc: (0.727)(0.609, 0.844)] [G loss: 1.362] [G acc: 0.172]
5975 [D loss: (0.538)(R 0.509, F 0.567)] [D acc: (0.703)(0.719, 0.688)] [G loss: 1.489] [G acc: 0.109]
5976 [D loss: (0.476)(R 0.570, F 0.382)] [D acc: (0.789)(0.672, 0.906)] [G loss: 1.592] [G acc: 0.031]
5977 [D loss: (0.485)(R 0.429, F 0.540)] [D acc: (0.758)(0.766, 0.750)] [G loss: 1.480] [G acc: 0.094]
5978 [D loss: (0.517)(R 0.544, F 0.490)] [D acc: (0.734)(0.672, 0.797)] [G loss: 1.528] [G acc: 0.062]
5979 [D loss: (0.561)(R 0.656, F 0.465)] [D acc: (0.680)(0.609, 0.750)] [G loss: 1.464] [G acc: 0.078]
5980 [D loss: (0.469)(R 0.509, F 0.430)] [D acc: (0.797)(0.734, 0.859)] [G loss: 1.584] [G acc: 0.062]
5981 [D loss: (0.513)(R 0.481, F 0.545)] [D acc: (0.719)(0.672, 0.766)] [G loss: 1.510] [G acc: 0.094]
5982 [D loss: (0.495)(R 0.575, F 0.415)] [D acc: (0.734)(0.625, 0.844)] [G loss: 1.469] [G acc: 0.109]
5983 [D loss: (0.525)(R 0.504, F 0.547)] [D acc: (0.773)(0.703, 0.844)] [G loss: 1.502] [G acc: 0.109]
5984 [D loss: (0.452)(R 0.478, F 0.427)] [D acc: (0.758)(0.688, 0.828)] [G loss: 1.471] [G acc: 0.125]
5985 [D loss: (0.551)(R 0.629, F 0.473)] [D acc: (0.766)(0.656, 0.875)] [G loss: 1.390] [G acc: 0.094]
5986 [D loss: (0.569)(R 0.511, F 0.627)] [D acc: (0.695)(0.703, 0.688)] [G loss: 1.601] [G acc: 0.047]
5987 [D loss: (0.538)(R 0.671, F 0.405)] [D acc: (0.734)(0.578, 0.891)] [G loss: 1.389] [G acc: 0.109]
5988 [D loss: (0.539)(R 0.529, F 0.550)] [D acc: (0.688)(0.656, 0.719)] [G loss: 1.338] [G acc: 0.125]
5989 [D loss: (0.594)(R 0.595, F 0.594)] [D acc: (0.703)(0.578, 0.828)] [G loss: 1.388] [G acc: 0.141]
5990 [D loss: (0.466)(R 0.484, F 0.447)] [D acc: (0.773)(0.703, 0.844)] [G loss: 1.316] [G acc: 0.203]
5991 [D loss: (0.596)(R 0.622, F 0.570)] [D acc: (0.656)(0.578, 0.734)] [G loss: 1.423] [G acc: 0.094]
5992 [D loss: (0.493)(R 0.483, F 0.502)] [D acc: (0.820)(0.734, 0.906)] [G loss: 1.557] [G acc: 0.078]
5993 [D loss: (0.492)(R 0.524, F 0.460)] [D acc: (0.766)(0.703, 0.828)] [G loss: 1.436] [G acc: 0.125]
5994 [D loss: (0.504)(R 0.539, F 0.468)] [D acc: (0.750)(0.672, 0.828)] [G loss: 1.558] [G acc: 0.047]
5995 [D loss: (0.464)(R 0.397, F 0.531)] [D acc: (0.820)(0.781, 0.859)] [G loss: 1.580] [G acc: 0.047]
5996 [D loss: (0.554)(R 0.602, F 0.506)] [D acc: (0.703)(0.625, 0.781)] [G loss: 1.404] [G acc: 0.062]
5997 [D loss: (0.509)(R 0.567, F 0.451)] [D acc: (0.758)(0.656, 0.859)] [G loss: 1.435] [G acc: 0.109]
5998 [D loss: (0.569)(R 0.416, F 0.721)] [D acc: (0.719)(0.781, 0.656)] [G loss: 1.422] [G acc: 0.094]
5999 [D loss: (0.493)(R 0.584, F 0.403)] [D acc: (0.781)(0.672, 0.891)] [G loss: 1.373] [G acc: 0.109]
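Each log line packs the overall discriminator loss with its real (R) and fake (F) components, the matching accuracies, and the generator loss and accuracy. Note the stalemate at the start: losses hover around 0.693 (that is ln 2, the binary cross-entropy when the discriminator outputs 0.5 everywhere), then settle into the healthier balance above, with D accuracy around 0.7-0.8. If you want these numbers back out of a saved log, a small helper (not part of the book's code) can parse them:

```python
import re

LINE = re.compile(
    r"(?P<batch>\d+) \[D loss: \((?P<d_loss>[\d.]+)\)"
    r"\(R (?P<d_real>[\d.]+), F (?P<d_fake>[\d.]+)\)\] "
    r"\[D acc: \((?P<d_acc>[\d.]+)\)\((?P<d_acc_real>[\d.]+), (?P<d_acc_fake>[\d.]+)\)\] "
    r"\[G loss: (?P<g_loss>[\d.]+)\] \[G acc: (?P<g_acc>[\d.]+)\]")

def parse_log_line(line):
    """Turn one training-log line into a dict of floats."""
    return {k: float(v) for k, v in LINE.match(line).groupdict().items()}

row = parse_log_line(
    "5999 [D loss: (0.493)(R 0.584, F 0.403)] "
    "[D acc: (0.781)(0.672, 0.891)] [G loss: 1.373] [G acc: 0.109]")
```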

 

fig = plt.figure()
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)

plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)
plt.plot([x[0] for x in gan.g_losses], color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.xlim(0, 2000)
plt.ylim(0, 2)

plt.show()

fig = plt.figure()
plt.plot([x[3] for x in gan.d_losses], color='black', linewidth=0.25)
plt.plot([x[4] for x in gan.d_losses], color='green', linewidth=0.25)
plt.plot([x[5] for x in gan.d_losses], color='red', linewidth=0.25)
plt.plot([x[1] for x in gan.g_losses], color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('accuracy', fontsize=16)

plt.xlim(0, 2000)

plt.show()
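The raw per-batch curves are noisy even at linewidth 0.25. A trailing moving average (a small helper added here, assuming gan.d_losses holds one tuple per batch as the plots above do) makes the trends easier to read:

```python
def moving_average(values, window=50):
    """Trailing moving average; returns len(values) - window + 1 points."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]
        if i >= window - 1:
            out.append(running / window)
    return out

# usage sketch: plt.plot(moving_average([x[0] for x in gan.d_losses], 50))
```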

(Generated sample images saved at epochs 20, 200, 400, 1000, and 2000 appeared here.)

Reference: GAN 참고 문헌 (https://broscoding.tistory.com/308). The code in this post follows the book 미술관에 GAN 딥러닝 실전 프로젝트 (the Korean edition of Generative Deep Learning), which covers Keras-based deep learning from the basics through four generative GAN projects: http://book.naver.com/bookdb/book_detail.nhn?bid=15660741

 

