반응형
Notice
Recent Posts
Recent Comments
Link
관리 메뉴

bro's coding

WGANGP.train.faces 본문

[AI]/GAN

WGANGP.train.faces

givemebro 2020. 7. 15. 16:53
반응형
%matplotlib inline

import os
import matplotlib.pyplot as plt

from models.WGANGP import WGANGP
from utils.loaders import load_celeb

import pickle

https://broscoding.tistory.com/330

 

WGANGP.source

from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D from keras.layers.merge impo..

broscoding.tistory.com

# Run parameters: every artefact of this run lives under
# run/<SECTION>/<RUN_ID>_<DATA_NAME>.
SECTION = 'gan'
RUN_ID = '0003'
DATA_NAME = 'celeb'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its sub-folders. os.makedirs also creates missing
# parents (run/, run/gan/), which os.mkdir cannot do. exist_ok=True fills in
# any sub-folder missing from a partially created run folder instead of
# skipping the whole step.
for _subdir in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, _subdir), exist_ok=True)

# 'build' constructs and saves a new model; 'load' restores saved weights.
mode = 'build'  # or 'load'

 

# Load the data.
# load_celeb's positional arguments are presumably
# (data name, image size, batch size) — TODO confirm in utils.loaders.
IMAGE_SIZE = 64
BATCH_SIZE = 64

x_train = load_celeb(
    DATA_NAME,
    IMAGE_SIZE,
    BATCH_SIZE,
)

# Inspect one sample; as the last expression of the cell, the notebook
# displays its value.
x_train[0][0][0]
array([[[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       [[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       [[-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        [-0.99215686, -0.99215686, -0.99215686],
        ...,
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ],
        [-0.9764706 , -0.9764706 , -0.9764706 ]],

       ...,

       [[ 0.05098039, -0.14509805, -0.46666667],
        [ 0.05098039, -0.14509805, -0.46666667],
        [ 0.05098039, -0.14509805, -0.46666667],
        ...,
        [ 0.04313726, -0.19215687, -0.49019608],
        [ 0.04313726, -0.19215687, -0.49019608],
        [ 0.05098039, -0.18431373, -0.48235294]],

       [[ 0.27058825,  0.07450981, -0.24705882],
        [ 0.27058825,  0.07450981, -0.24705882],
        [ 0.27058825,  0.07450981, -0.24705882],
        ...,
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.06666667, -0.16862746, -0.46666667]],

       [[ 0.3019608 ,  0.09803922, -0.19215687],
        [ 0.29411766,  0.09019608, -0.2       ],
        [ 0.3019608 ,  0.09803922, -0.19215687],
        ...,
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ],
        [ 0.05882353, -0.1764706 , -0.4745098 ]]], dtype=float32)

 

# Show one training image. The loader's pixel values appear to lie in [-1, 1]
# (TODO confirm against load_celeb), so shift and halve them into the
# [0, 1] range imshow expects for float RGB data.
sample_image = x_train[0][0][0]
plt.imshow((sample_image + 1) / 2)

# Create the model.
gan = WGANGP(
    input_dim=(IMAGE_SIZE, IMAGE_SIZE, 3),
    # Critic: four stride-2 5x5 convolutions (64x64 -> 4x4) with LeakyReLU;
    # no batch normalisation and no dropout.
    critic_conv_filters=[64, 128, 256, 512],
    critic_conv_kernel_size=[5, 5, 5, 5],
    critic_conv_strides=[2, 2, 2, 2],
    critic_batch_norm_momentum=None,
    critic_activation='leaky_relu',
    critic_dropout_rate=None,
    critic_learning_rate=0.0002,
    # Generator: Dense projection of z reshaped to 4x4x512, then four
    # stride-2 5x5 transposed convolutions up to a 64x64x3 image.
    generator_initial_dense_layer_size=(4, 4, 512),
    generator_upsample=[1, 1, 1, 1],
    generator_conv_filters=[256, 128, 64, 3],
    generator_conv_kernel_size=[5, 5, 5, 5],
    generator_conv_strides=[2, 2, 2, 2],
    generator_batch_norm_momentum=0.9,
    generator_activation='leaky_relu',
    generator_dropout_rate=None,
    generator_learning_rate=0.0002,
    optimiser='adam',
    # Presumably the gradient-penalty coefficient (lambda) — confirm in WGANGP.
    grad_weight=10,
    z_dim=100,
    batch_size=BATCH_SIZE,
)

# 'build': persist the freshly constructed model via gan.save(RUN_FOLDER).
# 'load' : restore previously trained weights from RUN_FOLDER/weights/weights.h5.
# Any other value is rejected, rather than silently falling into the load branch.
if mode == 'build':
    gan.save(RUN_FOLDER)
elif mode == 'load':
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights', 'weights.h5'))
else:
    raise ValueError("mode must be 'build' or 'load', got {!r}".format(mode))
    
    

 

# Print the critic's layer summary: four stride-2 convolutions, each followed
# by LeakyReLU, then Flatten and a single-unit Dense output.
gan.critic.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
critic_input (InputLayer)    (None, 64, 64, 3)         0         
_________________________________________________________________
critic_conv_0 (Conv2D)       (None, 32, 32, 64)        4864      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
critic_conv_1 (Conv2D)       (None, 16, 16, 128)       204928    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
critic_conv_2 (Conv2D)       (None, 8, 8, 256)         819456    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
critic_conv_3 (Conv2D)       (None, 4, 4, 512)         3277312   
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 4, 4, 512)         0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 8192)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 8193      
=================================================================
Total params: 4,314,753
Trainable params: 4,314,753
Non-trainable params: 0
_________________________________________________________________

 

# Print the generator's layer summary: the 100-d latent input goes through a
# Dense layer, is reshaped to 4x4x512, and is upsampled by transposed
# convolutions to a 64x64x3 image.
gan.generator.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator_input (InputLayer) (None, 100)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 8192)              827392    
_________________________________________________________________
batch_normalization_1 (Batch (None, 8192)              32768     
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU)    (None, 8192)              0         
_________________________________________________________________
reshape_1 (Reshape)          (None, 4, 4, 512)         0         
_________________________________________________________________
generator_conv_0 (Conv2DTran (None, 8, 8, 256)         3277056   
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 8, 8, 256)         0         
_________________________________________________________________
generator_conv_1 (Conv2DTran (None, 16, 16, 128)       819328    
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 128)       512       
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
generator_conv_2 (Conv2DTran (None, 32, 32, 64)        204864    
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 64)        256       
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU)    (None, 32, 32, 64)        0         
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 64, 64, 3)         4803      
_________________________________________________________________
activation_1 (Activation)    (None, 64, 64, 3)         0         
=================================================================
Total params: 5,168,003
Trainable params: 5,150,723
Non-trainable params: 17,280
_________________________________________________________________

 

# Print the composite model used to train the critic. Per its summary, it
# takes a latent vector and a real image and generates a fake image from the
# latent vector. It forms a random weighted average of the real and fake
# images and scores all three with the critic. Only the critic's weights are
# trainable here.
gan.critic_model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 100)          0                                            
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
model_2 (Model)                 (None, 64, 64, 3)    5168003     input_2[0][0]                    
__________________________________________________________________________________________________
random_weighted_average_1 (Rand (None, 64, 64, 3)    0           input_1[0][0]                    
                                                                 model_2[1][0]                    
__________________________________________________________________________________________________
model_1 (Model)                 (None, 1)            4314753     model_2[1][0]                    
                                                                 input_1[0][0]                    
                                                                 random_weighted_average_1[0][0]  
==================================================================================================
Total params: 4,332,033
Trainable params: 4,314,753
Non-trainable params: 17,280
__________________________________________________________________________________________________

 

 

# Train the model.
# BATCH_SIZE is deliberately not redefined here. The WGANGP instance was
# constructed with batch_size=BATCH_SIZE (and presumably depends on it), so
# training must keep using that same value.
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
# Critic steps per generator step (the "(5, 1)" printed in the training log).
N_CRITIC = 5

gan.train(
    x_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    run_folder=RUN_FOLDER,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
    n_critic=N_CRITIC,
    # x_train presumably comes from load_celeb as a batch iterator rather
    # than an in-memory array — TODO confirm.
    using_generator=True,
)
0 (5, 1) [D loss: (1.0)(R -3.5, F -1.3, G 0.6)] [G loss: 2.5]
1 (5, 1) [D loss: (-66.7)(R -87.1, F -10.1, G 3.1)] [G loss: 6.4]
2 (5, 1) [D loss: (-122.9)(R -206.2, F 15.1, G 6.8)] [G loss: -14.2]
3 (5, 1) [D loss: (-120.7)(R -217.9, F 16.6, G 8.1)] [G loss: -11.7]
4 (5, 1) [D loss: (-135.9)(R -195.2, F 6.5, G 5.3)] [G loss: -18.7]
5 (5, 1) [D loss: (-130.0)(R -204.2, F 11.5, G 6.3)] [G loss: -16.2]
6 (5, 1) [D loss: (-141.2)(R -219.0, F 5.5, G 7.2)] [G loss: -29.8]
7 (5, 1) [D loss: (-140.4)(R -214.9, F 7.2, G 6.7)] [G loss: -20.4]
8 (5, 1) [D loss: (-136.3)(R -210.3, F 14.8, G 5.9)] [G loss: -24.1]
9 (5, 1) [D loss: (-130.8)(R -218.7, F 0.6, G 8.7)] [G loss: -5.2]
10 (5, 1) [D loss: (-114.4)(R -176.0, F 4.9, G 5.7)] [G loss: -8.8]
.
.
.
5990 (5, 1) [D loss: (-6.8)(R -2.3, F -5.1, G 0.1)] [G loss: 6.4]
5991 (5, 1) [D loss: (-6.3)(R -4.4, F -3.0, G 0.1)] [G loss: 0.7]
5992 (5, 1) [D loss: (-4.4)(R 4.4, F -10.0, G 0.1)] [G loss: 9.8]
5993 (5, 1) [D loss: (-7.3)(R -8.5, F 0.6, G 0.1)] [G loss: 0.2]
5994 (5, 1) [D loss: (-5.2)(R -7.2, F 1.3, G 0.1)] [G loss: -0.5]
5995 (5, 1) [D loss: (-4.2)(R -6.9, F 1.9, G 0.1)] [G loss: 1.4]
5996 (5, 1) [D loss: (-6.9)(R -7.6, F -0.2, G 0.1)] [G loss: -2.9]
5997 (5, 1) [D loss: (-6.1)(R -4.6, F -2.6, G 0.1)] [G loss: 3.6]
5998 (5, 1) [D loss: (-6.8)(R -0.9, F -6.7, G 0.1)] [G loss: 8.2]
5999 (5, 1) [D loss: (-5.2)(R -6.7, F 0.6, G 0.1)] [G loss: 3.4]

 

# Plot the loss curves recorded during training.
# NOTE(review): the d_losses indices are assumed to follow the training log's
# "D loss: (total)(R real, F fake, G gradient penalty)" order — TODO confirm
# against WGANGP.train.
fig = plt.figure()
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25, label='critic total')
plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25, label='critic real')
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25, label='critic fake')
plt.plot(gan.g_losses, color='orange', linewidth=0.25, label='generator')
# Without a legend, the four colours are indistinguishable to a reader.
plt.legend()

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

# Zoom in on the first 2000 iterations (training ran for EPOCHS = 6000).
plt.xlim(0, 2000)

plt.show()

https://broscoding.tistory.com/308

 

GAN 참고 문헌

http://book.naver.com/bookdb/book_detail.nhn?bid=15660741 미술관에 GAN 딥러닝 실전 프로젝트 창조에 다가서는 GAN의 4가지 생성 프로젝트이 책은 케라스를 사용한 딥러닝 기초부터 AI 분야 최신 알고리즘까지..

broscoding.tistory.com

 

반응형

'[AI] > GAN' 카테고리의 다른 글

WGANGP.source  (0) 2020.07.15
Threshold(임계치)  (0) 2020.07.15
GAN VS DCGAN  (0) 2020.07.15
WGAN.weight clipping(가중치 클리핑)  (0) 2020.07.15
WGAN.source  (0) 2020.07.15
WGAN.mode collapse(모드 붕괴)  (0) 2020.07.15
GAN.source  (0) 2020.07.15
WGAN.example.train.cifar  (0) 2020.07.14
Comments