[AI]/GAN

WGANGP.train.faces

givemebro 2020. 7. 15. 16:53
반응형
%matplotlib inline

import os
import matplotlib.pyplot as plt

from models.WGANGP import WGANGP
from utils.loaders import load_celeb

import pickle

https://broscoding.tistory.com/330

 

WGANGP.source

from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from keras.layers.merge impo..

broscoding.tistory.com

# run params
SECTION = 'gan'
RUN_ID = '0003'
DATA_NAME = 'celeb'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its sub-folders.
# os.makedirs also creates missing parents ('run/', 'run/gan/'), which plain
# os.mkdir fails on for a fresh checkout. exist_ok=True makes re-running this
# cell harmless and fills in any sub-folder missing from a partially created
# run folder (the old "if not exists" guard skipped those).
for _sub_dir in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, _sub_dir), exist_ok=True)

# 'build' saves a freshly built model; 'load' restores saved weights.
mode = 'build'  # or 'load'

 

# Load the data
IMAGE_SIZE = 64
BATCH_SIZE = 64

x_train = load_celeb(DATA_NAME, IMAGE_SIZE, BATCH_SIZE)

# Peek at one sample; it is displayed as the cell output.
# NOTE(review): the printed value is an (H, W, 3) array, so this is presumably
# the first image of the first batch — confirm against utils.loaders.load_celeb.
x_train[0][0][0]
array([[[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       [[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       [[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       ...,

       [[ 0.05098039, -0.14509805, -0.46666667],
        [ 0.05098039, -0.14509805, -0.46666667],
        [ 0.05098039, -0.14509805, -0.46666667],
        ...,
        [ 0.04313726, -0.19215687, -0.49019608],
        [ 0.04313726, -0.19215687, -0.49019608],
        [ 0.05098039, -0.18431373, -0.48235294]],

       [[ 0.27058825,  0.07450981, -0.24705882],
        [ 0.27058825,  0.07450981, -0.24705882],
        [ 0.27058825,  0.07450981, -0.24705882],
        ...,
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.06666667, -0.16862746, -0.46666667]],

       [[ 0.3019608 ,  0.09803922, -0.19215687],
        [ 0.29411766,  0.09019608, -0.2       ],
        [ 0.3019608 ,  0.09803922, -0.19215687],
        ...,
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ]]], dtype=float32)

 

# The loaded pixel values appear to lie in [-1, 1] (see the array above).
# imshow expects float RGB in [0, 1], so shift by one and halve before drawing.
first_image = x_train[0][0][0]
plt.imshow((first_image + 1) / 2)

# create model
# These hyperparameters give the architecture printed by the summaries below:
# four strided 5x5 Conv2D layers in the critic (64 -> 512 filters), and a
# generator that projects a 100-d noise vector to 4x4x512, then applies four
# Conv2DTranspose layers up to a 64x64x3 image.
# grad_weight presumably scales the gradient-penalty term — confirm in WGANGP.
gan = WGANGP(
    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),
    critic_conv_filters=[64, 128, 256, 512],
    critic_conv_kernel_size=[5, 5, 5, 5],
    critic_conv_strides=[2, 2, 2, 2],
    critic_batch_norm_momentum=None,
    critic_activation='leaky_relu',
    critic_dropout_rate=None,
    critic_learning_rate=0.0002,
    generator_initial_dense_layer_size=(4, 4, 512),
    generator_upsample=[1, 1, 1, 1],
    generator_conv_filters=[256, 128, 64, 3],
    generator_conv_kernel_size=[5, 5, 5, 5],
    generator_conv_strides=[2, 2, 2, 2],
    generator_batch_norm_momentum=0.9,
    generator_activation='leaky_relu',
    generator_dropout_rate=None,
    generator_learning_rate=0.0002,
    optimiser='adam',
    grad_weight=10,
    z_dim=100,
    batch_size=BATCH_SIZE,
)

# Dispatch on `mode` explicitly. Previously any value other than 'build'
# (including a typo) silently fell through to loading weights.
if mode == 'build':
    # Newly built model: save it into the run folder.
    gan.save(RUN_FOLDER)
elif mode == 'load':
    # Resume from weights saved by an earlier run.
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights', 'weights.h5'))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))
    
    

 

# Print the critic network's layer-by-layer architecture.
critic_net = gan.critic
critic_net.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
critic_input (InputLayer)    (None, 64, 64, 3)         0         
_________________________________________________________________
critic_conv_0 (Conv2D)       (None, 32, 32, 64)        4864      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
critic_conv_1 (Conv2D)       (None, 16, 16, 128)       204928    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
critic_conv_2 (Conv2D)       (None, 8, 8, 256)         819456    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
critic_conv_3 (Conv2D)       (None, 4, 4, 512)         3277312   
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 4, 512)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 8192)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 8193      
=================================================================
Total params: 4,314,753
Trainable params: 4,314,753
Non-trainable params: 0
_________________________________________________________________

 

# Print the generator network's layer-by-layer architecture.
generator_net = gan.generator
generator_net.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 8192)              827392    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8192)              32768     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8192)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 512)         0         
_________________________________________________________________
generator_conv_0 (Conv2DTran (None, 8, 8, 256)         3277056   
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
generator_conv_1 (Conv2DTran (None, 16, 16, 128)       819328    
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
generator_conv_2 (Conv2DTran (None, 32, 32, 64)        204864    
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 64, 64, 3)         4803      
_________________________________________________________________
activation_1 (Activation)    (None, 64, 64, 3)         0         
=================================================================
Total params: 5,168,003
Trainable params: 5,150,723
Non-trainable params: 17,280
_________________________________________________________________

 

# Print the combined model used to train the critic. Per its summary, it wires
# the generator, the critic, and a random weighted average of real and
# generated images.
critic_training_model = gan.critic_model
critic_training_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 64, 64, 3)    5168003     input_2[0][0]                    
__________________________________________________________________________________________________
random_weighted_average_1 (Rand (None, 64, 64, 3)    0           input_1[0][0]                    
                                                                 model_2[1][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 1)            4314753     model_2[1][0]                    
                                                                 input_1[0][0]                    
                                                                 random_weighted_average_1[0][0]  
==================================================================================================
Total params: 4,332,033
Trainable params: 4,314,753
Non-trainable params: 17,280
__________________________________________________________________________________________________

 

 

# train model

EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
N_CRITIC = 5
# BATCH_SIZE is deliberately not redefined here. load_celeb and the WGANGP
# constructor were both given the BATCH_SIZE from the data-loading cell, and
# training should use that same value. A second hard-coded copy could silently
# drift out of sync with them.

gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    n_critic=N_CRITIC,
    using_generator=True,
)
0 (5, 1) [D loss: (1.0)(R -3.5, F -1.3, G 0.6)] [G loss: 2.5]
1 (5, 1) [D loss: (-66.7)(R -87.1, F -10.1, G 3.1)] [G loss: 6.4]
2 (5, 1) [D loss: (-122.9)(R -206.2, F 15.1, G 6.8)] [G loss: -14.2]
3 (5, 1) [D loss: (-120.7)(R -217.9, F 16.6, G 8.1)] [G loss: -11.7]
4 (5, 1) [D loss: (-135.9)(R -195.2, F 6.5, G 5.3)] [G loss: -18.7]
5 (5, 1) [D loss: (-130.0)(R -204.2, F 11.5, G 6.3)] [G loss: -16.2]
6 (5, 1) [D loss: (-141.2)(R -219.0, F 5.5, G 7.2)] [G loss: -29.8]
7 (5, 1) [D loss: (-140.4)(R -214.9, F 7.2, G 6.7)] [G loss: -20.4]
8 (5, 1) [D loss: (-136.3)(R -210.3, F 14.8, G 5.9)] [G loss: -24.1]
9 (5, 1) [D loss: (-130.8)(R -218.7, F 0.6, G 8.7)] [G loss: -5.2]
10 (5, 1) [D loss: (-114.4)(R -176.0, F 4.9, G 5.7)] [G loss: -8.8]
.
.
.
5990 (5, 1) [D loss: (-6.8)(R -2.3, F -5.1, G 0.1)] [G loss: 6.4]
5991 (5, 1) [D loss: (-6.3)(R -4.4, F -3.0, G 0.1)] [G loss: 0.7]
5992 (5, 1) [D loss: (-4.4)(R 4.4, F -10.0, G 0.1)] [G loss: 9.8]
5993 (5, 1) [D loss: (-7.3)(R -8.5, F 0.6, G 0.1)] [G loss: 0.2]
5994 (5, 1) [D loss: (-5.2)(R -7.2, F 1.3, G 0.1)] [G loss: -0.5]
5995 (5, 1) [D loss: (-4.2)(R -6.9, F 1.9, G 0.1)] [G loss: 1.4]
5996 (5, 1) [D loss: (-6.9)(R -7.6, F -0.2, G 0.1)] [G loss: -2.9]
5997 (5, 1) [D loss: (-6.1)(R -4.6, F -2.6, G 0.1)] [G loss: 3.6]
5998 (5, 1) [D loss: (-6.8)(R -0.9, F -6.7, G 0.1)] [G loss: 8.2]
5999 (5, 1) [D loss: (-5.2)(R -6.7, F 0.6, G 0.1)] [G loss: 3.4]

 

fig = plt.figure()

# Draw columns 0, 1 and 2 of each critic-loss entry, then the generator loss,
# all as thin lines.
# NOTE(review): judging by the training log format "(total)(R ..., F ..., G ...)",
# column 0 is presumably the total critic loss, 1 the real-image term and 2 the
# fake-image term — confirm against WGANGP.train.
critic_colours = ('black', 'green', 'red')
for column, colour in enumerate(critic_colours):
    series = [entry[column] for entry in gan.d_losses]
    plt.plot(series, color=colour, linewidth=0.25)
plt.plot(gan.g_losses, color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Restrict the x-axis to the first 2000 training steps.
plt.xlim(0, 2000)
# plt.ylim(0, 2)

plt.show()

https://broscoding.tistory.com/308

 

GAN 참고 문헌

http://book.naver.com/bookdb/book_detail.nhn?bid=15660741 미술관에 GAN 딥러닝 실전 프로젝트: 창조에 다가서는 GAN의 4가지 생성 프로젝트. 이 책은 케라스를 사용한 딥러닝 기초부터 AI 분야 최신 알고리즘까지..

broscoding.tistory.com

 

반응형