반응형
Notice
Recent Posts
Recent Comments
Link
관리 메뉴

bro's coding

WGAN.example.train.cifar 본문

[AI]/GAN

WGAN.example.train.cifar

givemebro 2020. 7. 14. 15:23
반응형

WGAN: Wasserstein 손실 함수를 GAN에 적용한 모델 (원래 GAN의 손실 함수는 이진 교차 엔트로피(binary cross-entropy))

import os
import numpy as np
import matplotlib.pyplot as plt

from models.WGAN import WGAN
from utils.loaders import load_cifar

 

# Run parameters: identify this experiment and where its outputs go.
SECTION = 'gan'
RUN_ID = '0002'
DATA_NAME = 'horses'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its sub-folders. os.makedirs(..., exist_ok=True)
# also creates missing parents (e.g. 'run/gan/') and fills in any sub-folder
# missing from a partially created run folder; os.mkdir raised in both cases.
for sub_folder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_folder), exist_ok=True)

# 'build' constructs and saves a fresh model; any other value (e.g. 'load')
# restores weights from RUN_FOLDER/weights/weights.h5 instead.
mode = 'build'  # 'load'

 

 

# Load data: the images of a single CIFAR class selected by DATA_NAME.

# Class index passed to load_cifar for each supported DATA_NAME
# (these match the standard CIFAR-10 ordering: 1 = automobile, 7 = horse).
CIFAR_LABELS = {'cars': 1, 'horses': 7}

try:
    label = CIFAR_LABELS[DATA_NAME]
except KeyError:
    # Previously an unknown name left `label` undefined and failed later
    # with a confusing NameError.
    raise ValueError(
        "Unsupported DATA_NAME {!r}; expected one of {}".format(
            DATA_NAME, sorted(CIFAR_LABELS))) from None

# NOTE(review): second argument is presumably the number of CIFAR classes
# (10) — confirm against utils.loaders.load_cifar.
(x_train, y_train) = load_cifar(label, 10)

# Preview one training image. The (x + 1) / 2 rescale assumes load_cifar
# returns pixels in [-1, 1] — confirm.
plt.imshow((x_train[150, :, :, :] + 1) / 2)

# Build the model.

# The network must exist in both modes: 'load' only restores saved weights
# into it. Constructing it inside the 'build' branch alone left `gan`
# undefined (NameError) when mode == 'load'.
gan = WGAN(
    input_dim=(32, 32, 3),
    critic_conv_filters=[32, 64, 128, 128],
    critic_conv_kernel_size=[5, 5, 5, 5],
    critic_conv_strides=[2, 2, 2, 1],
    critic_batch_norm_momentum=None,
    critic_activation='leaky_relu',
    critic_dropout_rate=None,
    critic_learning_rate=0.00005,
    generator_initial_dense_layer_size=(4, 4, 128),
    generator_upsample=[2, 2, 2, 1],
    generator_conv_filters=[128, 64, 32, 3],
    generator_conv_kernel_size=[5, 5, 5, 5],
    generator_conv_strides=[1, 1, 1, 1],
    generator_batch_norm_momentum=0.8,
    generator_activation='leaky_relu',
    generator_dropout_rate=None,
    generator_learning_rate=0.00005,
    optimiser='rmsprop',
    z_dim=100,
)

if mode == 'build':
    # Write this run's model artefacts into RUN_FOLDER (see WGAN.save).
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))

# Print a layer-by-layer summary of the critic network.
gan.critic.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
critic_input (InputLayer)    (None, 32, 32, 3)         0         
_________________________________________________________________
critic_conv_0 (Conv2D)       (None, 16, 16, 32)        2432      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 32)        0         
_________________________________________________________________
critic_conv_1 (Conv2D)       (None, 8, 8, 64)          51264     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
critic_conv_2 (Conv2D)       (None, 4, 4, 128)         204928    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
critic_conv_3 (Conv2D)       (None, 4, 4, 128)         409728    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 670,401
Trainable params: 670,401
Non-trainable params: 0
_________________________________________________________________

 

 

# Print a layer-by-layer summary of the generator network.
generator_model = gan.generator
generator_model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 2048)              206848    
_________________________________________________________________
batch_normalization_1 (Batch (None, 2048)              8192      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 2048)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 128)         0         
_________________________________________________________________
generator_conv_0 (Conv2D)    (None, 8, 8, 128)         409728    
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 128)       0         
_________________________________________________________________
generator_conv_1 (Conv2D)    (None, 16, 16, 64)        204864    
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 64)        256       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 32, 32, 64)        0         
_________________________________________________________________
generator_conv_2 (Conv2D)    (None, 32, 32, 32)        51232     
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 32)        128       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 32)        0         
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 32, 32, 3)         2403      
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 3)         0         
=================================================================
Total params: 884,163
Trainable params: 879,619
Non-trainable params: 4,544
_________________________________________________________________

 

# Train the model.
# NOTE(review): the training log counts steps 0..5999 and the loss plot's
# x-axis is labelled 'batch', so EPOCHS appears to count training iterations
# rather than full passes over x_train — confirm in models.WGAN.train.
BATCH_SIZE = 128
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
N_CRITIC = 5            # presumably critic updates per generator update
CLIP_THRESHOLD = 0.01   # presumably the critic weight-clipping bound

train_kwargs = dict(
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    n_critic=N_CRITIC,
    clip_threshold=CLIP_THRESHOLD,
)
gan.train(x_train, **train_kwargs)
0 [D loss: (-0.000)(R -0.001, F 0.000)]  [G loss: -0.000] 
1 [D loss: (0.000)(R -0.001, F 0.001)]  [G loss: -0.000] 
2 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.000] 
3 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.001] 
4 [D loss: (-0.000)(R -0.002, F 0.001)]  [G loss: -0.001] 
5 [D loss: (-0.000)(R -0.002, F 0.001)]  [G loss: -0.001] 
6 [D loss: (-0.001)(R -0.003, F 0.002)]  [G loss: -0.001] 
7 [D loss: (-0.001)(R -0.004, F 0.003)]  [G loss: -0.001] 
8 [D loss: (-0.000)(R -0.003, F 0.002)]  [G loss: -0.001] 
9 [D loss: (-0.000)(R -0.004, F 0.003)]  [G loss: -0.002] 
10 [D loss: (-0.001)(R -0.005, F 0.002)]  [G loss: -0.003] 
.
.
.
5990 [D loss: (0.003)(R -0.058, F 0.064)]  [G loss: -0.048] 
5991 [D loss: (-0.004)(R -0.070, F 0.061)]  [G loss: -0.052] 
5992 [D loss: (-0.002)(R -0.068, F 0.063)]  [G loss: -0.049] 
5993 [D loss: (-0.000)(R -0.055, F 0.055)]  [G loss: -0.042] 
5994 [D loss: (-0.011)(R -0.067, F 0.045)]  [G loss: -0.036] 
5995 [D loss: (-0.008)(R -0.065, F 0.049)]  [G loss: -0.045] 
5996 [D loss: (-0.012)(R -0.068, F 0.043)]  [G loss: -0.039] 
5997 [D loss: (-0.012)(R -0.070, F 0.045)]  [G loss: -0.038] 
5998 [D loss: (-0.008)(R -0.060, F 0.045)]  [G loss: -0.037] 
5999 [D loss: (-0.011)(R -0.070, F 0.049)]  [G loss: -0.043] 

 

# Save a grid of generated samples, then plot the training loss curves.
gan.sample_images(RUN_FOLDER)

fig = plt.figure()
# Each d_losses entry is indexed [0], [1], [2]. The labels follow the training
# log's "D loss: (total)(R real, F fake)" order — confirm against models.WGAN.
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25, label='D loss')
plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25, label='D loss (real)')
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25, label='D loss (fake)')
plt.plot(gan.g_losses, color='orange', linewidth=0.25, label='G loss')

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)
# Without a legend the four curves are indistinguishable to the reader.
plt.legend()

# To zoom in, set axis limits, e.g. plt.xlim(0, 2000) and plt.ylim(0, 2).

plt.show()

def compare_images(img1, img2):
    """Return the mean absolute pixel difference between two images.

    Args:
        img1, img2: array-likes of the same (or broadcast-compatible) shape.

    Returns:
        The mean of ``|img1 - img2|`` over all elements, as a NumPy scalar.

    Integer inputs (e.g. uint8 pixels) are promoted to a float dtype before
    subtracting, because unsigned integer subtraction wraps around
    (0 - 1 == 255 for uint8). float32 / float64 inputs keep their precision.
    """
    a = np.asarray(img1)
    b = np.asarray(img2)
    work_dtype = np.result_type(a.dtype, b.dtype, np.float32)
    return np.mean(np.abs(a.astype(work_dtype) - b.astype(work_dtype)))
# Show an r x c grid of randomly chosen real training images.
r, c = 5, 5

# Draw exactly as many images as the grid shows. Drawing BATCH_SIZE indices
# wasted work and would raise IndexError below if BATCH_SIZE < r * c.
idx = np.random.randint(0, x_train.shape[0], r * c)
# Rescale for display; assumes x_train is in [-1, 1] — confirm.
true_imgs = (x_train[idx] + 1) * 0.5

fig, axs = plt.subplots(r, c, figsize=(15, 15))
for ax, img in zip(axs.flat, true_imgs):  # row-major, same order as i/j loops
    ax.imshow(img, cmap='gray_r')  # cmap only affects single-channel images
    ax.axis('off')
fig.savefig(os.path.join(RUN_FOLDER, "images/real.png"))
plt.show()

# Generate r * c images from standard-normal latent vectors and save them.
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, gan.z_dim))
gen_imgs = gan.generator.predict(noise)

# Map generator output to [0, 1] for display (assumes a [-1, 1] output range).
# Values are not clipped here; imshow clips out-of-range float RGB itself.
gen_imgs = 0.5 * (gen_imgs + 1)

fig, axs = plt.subplots(r, c, figsize=(15, 15))
for panel, generated in zip(axs.flat, gen_imgs):
    panel.imshow(np.squeeze(generated), cmap='gray_r')
    panel.axis('off')
fig.savefig(os.path.join(RUN_FOLDER, "images/sample.png"))
plt.close()


# For each generated image, show the real training image with the smallest
# mean absolute pixel difference (the same metric as compare_images). This is
# useful as a rough check that the generator is not copying training images.
fig, axs = plt.subplots(r, c, figsize=(15, 15))

# Rescale the training set once. It was previously recomputed over the whole
# dataset for every grid cell.
real_imgs = (x_train + 1) * 0.5
per_image_axes = tuple(range(1, real_imgs.ndim))

for ax, gen_img in zip(axs.flat, gen_imgs):
    # Vectorised distance to every training image. argmin returns the first
    # minimum, matching the previous strict '<' scan.
    diffs = np.abs(real_imgs - gen_img).mean(axis=per_image_axes)
    closest = real_imgs[np.argmin(diffs)]
    ax.imshow(closest, cmap='gray_r')
    ax.axis('off')

fig.savefig(os.path.join(RUN_FOLDER, "images/sample_closest.png"))
plt.show()

https://broscoding.tistory.com/308

 

GAN 참고 문헌

http://book.naver.com/bookdb/book_detail.nhn?bid=15660741 미술관에 GAN 딥러닝 실전 프로젝트: 창조에 다가서는 GAN의 4가지 생성 프로젝트. 이 책은 케라스를 사용한 딥러닝 기초부터 AI 분야 최신 알고리즘까지...

broscoding.tistory.com

 

반응형

'[AI] > GAN' 카테고리의 다른 글

WGAN.weight clipping(가중치 클리핑)  (0) 2020.07.15
WGAN.source  (0) 2020.07.15
WGAN.mode collapse(모드 붕괴)  (0) 2020.07.15
GAN.source  (0) 2020.07.15
GAN.train.camel  (0) 2020.07.14
upsampling(업샘플링)  (0) 2020.07.14
GAN이란...?  (0) 2020.07.08
정형 데이터와 비정형 데이터  (0) 2020.07.03
Comments