[AI]/GAN

WGAN.example.train.cifar

givemebro 2020. 7. 14. 15:23
반응형

WGAN: Wasserstein 손실 함수를 GAN에 적용 (원래 GAN의 손실 함수는 binary cross-entropy)

import os
import numpy as np
import matplotlib.pyplot as plt

from models.WGAN import WGAN
from utils.loaders import load_cifar

 

# Run parameters: every artefact of this run (visualisations, sample images,
# weights) is written under RUN_FOLDER, e.g. 'run/gan/0002_horses'.
SECTION = 'gan'
RUN_ID = '0002'
DATA_NAME = 'horses'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# os.makedirs creates missing parent directories (e.g. 'run/gan/' on a fresh
# checkout) and, with exist_ok=True, also restores any sub-folder that is
# missing from an already existing run folder instead of skipping it.
for sub_folder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, sub_folder), exist_ok=True)

# 'build' creates a fresh model; 'load' restores weights from RUN_FOLDER.
mode = 'build'  # or 'load'

 

 

# Load the data: the CIFAR images of a single class selected by DATA_NAME.

# Class index passed to load_cifar for each supported DATA_NAME.
CIFAR_LABELS = {'cars': 1, 'horses': 7}

if DATA_NAME not in CIFAR_LABELS:
    # Fail loudly here instead of hitting an unbound `label` below.
    raise ValueError(
        'Unsupported DATA_NAME {!r}; expected one of {}'.format(
            DATA_NAME, sorted(CIFAR_LABELS)))
label = CIFAR_LABELS[DATA_NAME]
# The second argument (10) presumably selects CIFAR-10 -- TODO confirm in
# utils.loaders.load_cifar.
(x_train, y_train) = load_cifar(label, 10)

# Preview one training image. The (+1) / 2 rescale assumes load_cifar returns
# pixels in [-1, 1] (TODO confirm); imshow expects float RGB in [0, 1].
plt.imshow((x_train[150, :, :, :] + 1) / 2)

# Build the model. The WGAN is constructed in both modes because 'load' only
# restores saved weights into a model with the same architecture (otherwise
# `gan` would be undefined when load_weights is called).
gan = WGAN(
    input_dim=(32, 32, 3),
    # Critic: the printed critic summary contains only Conv2D/LeakyReLU
    # layers, consistent with batch norm and dropout being None here.
    critic_conv_filters=[32, 64, 128, 128],
    critic_conv_kernel_size=[5, 5, 5, 5],
    critic_conv_strides=[2, 2, 2, 1],
    critic_batch_norm_momentum=None,
    critic_activation='leaky_relu',
    critic_dropout_rate=None,
    critic_learning_rate=0.00005,
    # Generator: a dense layer reshaped to 4x4x128, upsampled to 32x32x3.
    generator_initial_dense_layer_size=(4, 4, 128),
    generator_upsample=[2, 2, 2, 1],
    generator_conv_filters=[128, 64, 32, 3],
    generator_conv_kernel_size=[5, 5, 5, 5],
    generator_conv_strides=[1, 1, 1, 1],
    generator_batch_norm_momentum=0.8,
    generator_activation='leaky_relu',
    generator_dropout_rate=None,
    generator_learning_rate=0.00005,
    optimiser='rmsprop',
    z_dim=100,
)

if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))

 

# Print the critic's layer-by-layer summary (output shapes and parameter counts).
gan.critic.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
critic_input (InputLayer)    (None, 32, 32, 3)         0         
_________________________________________________________________
critic_conv_0 (Conv2D)       (None, 16, 16, 32)        2432      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 32)        0         
_________________________________________________________________
critic_conv_1 (Conv2D)       (None, 8, 8, 64)          51264     
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 64)          0         
_________________________________________________________________
critic_conv_2 (Conv2D)       (None, 4, 4, 128)         204928    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
critic_conv_3 (Conv2D)       (None, 4, 4, 128)         409728    
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 4, 128)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 670,401
Trainable params: 670,401
Non-trainable params: 0
_________________________________________________________________

 

 

# Print the generator's layer-by-layer summary (output shapes and parameter counts).
gan.generator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 2048)              206848    
_________________________________________________________________
batch_normalization_1 (Batch (None, 2048)              8192      
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 2048)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 128)         0         
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 128)         0         
_________________________________________________________________
generator_conv_0 (Conv2D)    (None, 8, 8, 128)         409728    
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 128)         512       
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 128)       0         
_________________________________________________________________
generator_conv_1 (Conv2D)    (None, 16, 16, 64)        204864    
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 64)        256       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 64)        0         
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 32, 32, 64)        0         
_________________________________________________________________
generator_conv_2 (Conv2D)    (None, 32, 32, 32)        51232     
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 32)        128       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 32)        0         
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 32, 32, 3)         2403      
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 3)         0         
=================================================================
Total params: 884,163
Trainable params: 879,619
Non-trainable params: 4,544
_________________________________________________________________

 

# Train the model.
BATCH_SIZE = 128
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
# Presumably the number of critic updates per generator update -- TODO
# confirm in WGAN.train.
N_CRITIC = 5
# Presumably the bound used to clip critic weights -- TODO confirm in
# WGAN.train.
CLIP_THRESHOLD = 0.01

gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    n_critic=N_CRITIC,
    clip_threshold=CLIP_THRESHOLD,
)
0 [D loss: (-0.000)(R -0.001, F 0.000)]  [G loss: -0.000] 
1 [D loss: (0.000)(R -0.001, F 0.001)]  [G loss: -0.000] 
2 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.000] 
3 [D loss: (-0.000)(R -0.001, F 0.001)]  [G loss: -0.001] 
4 [D loss: (-0.000)(R -0.002, F 0.001)]  [G loss: -0.001] 
5 [D loss: (-0.000)(R -0.002, F 0.001)]  [G loss: -0.001] 
6 [D loss: (-0.001)(R -0.003, F 0.002)]  [G loss: -0.001] 
7 [D loss: (-0.001)(R -0.004, F 0.003)]  [G loss: -0.001] 
8 [D loss: (-0.000)(R -0.003, F 0.002)]  [G loss: -0.001] 
9 [D loss: (-0.000)(R -0.004, F 0.003)]  [G loss: -0.002] 
10 [D loss: (-0.001)(R -0.005, F 0.002)]  [G loss: -0.003] 
.
.
.
5990 [D loss: (0.003)(R -0.058, F 0.064)]  [G loss: -0.048] 
5991 [D loss: (-0.004)(R -0.070, F 0.061)]  [G loss: -0.052] 
5992 [D loss: (-0.002)(R -0.068, F 0.063)]  [G loss: -0.049] 
5993 [D loss: (-0.000)(R -0.055, F 0.055)]  [G loss: -0.042] 
5994 [D loss: (-0.011)(R -0.067, F 0.045)]  [G loss: -0.036] 
5995 [D loss: (-0.008)(R -0.065, F 0.049)]  [G loss: -0.045] 
5996 [D loss: (-0.012)(R -0.068, F 0.043)]  [G loss: -0.039] 
5997 [D loss: (-0.012)(R -0.070, F 0.045)]  [G loss: -0.038] 
5998 [D loss: (-0.008)(R -0.060, F 0.045)]  [G loss: -0.037] 
5999 [D loss: (-0.011)(R -0.070, F 0.049)]  [G loss: -0.043] 

 

# Save a grid of generated samples via the model's helper, then plot the loss
# histories recorded during training.
gan.sample_images(RUN_FOLDER)

fig = plt.figure()
# Each gan.d_losses entry is indexed [0], [1], [2]; the order is assumed to
# follow the "D loss: (D)(R real, F fake)" columns of the training log
# -- TODO confirm in WGAN.train.
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25,
         label='critic')
plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25,
         label='critic (real)')
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25,
         label='critic (fake)')
plt.plot(gan.g_losses, color='orange', linewidth=0.25, label='generator')

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)
# Four thin, overlapping curves cannot be told apart without a key.
plt.legend()

# Optional zoom, e.g. on the early part of training:
# plt.xlim(0, 2000)
# plt.ylim(0, 2)

plt.show()

def compare_images(img1, img2):
    """Return the mean absolute per-pixel difference between two images.

    Both inputs are promoted to float before subtracting: with unsigned
    integer arrays (e.g. uint8 pixel data) ``img1 - img2`` wraps around
    instead of going negative, which would make near-identical images look
    very different. Promotion also lets plain sequences be passed.

    Args:
        img1: First image (array-like).
        img2: Second image (array-like), broadcast-compatible with ``img1``.

    Returns:
        A numpy float scalar; 0.0 for identical images.
    """
    diff = np.asarray(img1, dtype=float) - np.asarray(img2, dtype=float)
    return np.mean(np.abs(diff))
# Show a grid of real training images for comparison with generated samples.
r, c = 5, 5

# Draw exactly one random training image per grid cell.
idx = np.random.randint(0, x_train.shape[0], r * c)
# Rescale from [-1, 1] to [0, 1] for display (assumed input range -- see the
# data-loading preview, which applies the same rescale).
true_imgs = (x_train[idx] + 1) * 0.5

fig, axs = plt.subplots(r, c, figsize=(15, 15))
for ax, img in zip(axs.flat, true_imgs):
    # The images are RGB, for which imshow ignores any cmap, so none is passed.
    ax.imshow(img)
    ax.axis('off')
fig.savefig(os.path.join(RUN_FOLDER, "images/real.png"))
plt.show()

# Generate an r x c grid of samples from standard-normal latent vectors.
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, gan.z_dim))
gen_imgs = gan.generator.predict(noise)

# Map generator output to [0, 1] for display; this assumes the generator
# emits values in [-1, 1] -- TODO confirm its final activation.
# If out-of-range values matter, apply np.clip(gen_imgs, 0, 1) afterwards.
gen_imgs = 0.5 * (gen_imgs + 1)

fig, axs = plt.subplots(r, c, figsize=(15, 15))
for sample_no, ax in enumerate(axs.flat):
    ax.imshow(np.squeeze(gen_imgs[sample_no, :, :, :]), cmap='gray_r')
    ax.axis('off')
fig.savefig(os.path.join(RUN_FOLDER, "images/sample.png"))
plt.close()


# For each generated sample, display the training image with the smallest
# mean absolute pixel difference to it (a rough check for samples that are
# near-copies of training images).
fig, axs = plt.subplots(r, c, figsize=(15, 15))

# Rescale the training set once, outside the loop, instead of rebuilding the
# whole rescaled array for every grid cell.
train_imgs = (x_train + 1) * 0.5

for ax, generated in zip(axs.flat, gen_imgs):
    # min() keeps the first image with the smallest difference, i.e. the same
    # choice as a strict "<" scan over the training set. The default argument
    # binds the current sample explicitly.
    closest = min(train_imgs,
                  key=lambda real, g=generated: compare_images(g, real))
    # The images are RGB, for which imshow ignores any cmap, so none is passed.
    ax.imshow(closest)
    ax.axis('off')

fig.savefig(os.path.join(RUN_FOLDER, "images/sample_closest.png"))
plt.show()

https://broscoding.tistory.com/308

 

GAN 참고 문헌

http://book.naver.com/bookdb/book_detail.nhn?bid=15660741 미술관에 GAN 딥러닝 실전 프로젝트 — 창조에 다가서는 GAN의 4가지 생성 프로젝트. 이 책은 케라스를 사용한 딥러닝 기초부터 AI 분야 최신 알고리즘까지…

broscoding.tistory.com

 

반응형