bro's coding
WGAN.example.train.cifar
WGAN: applies the Wasserstein loss function to a GAN (the original GAN's loss function is binary cross-entropy).
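To make the difference concrete, here is a minimal NumPy sketch of the two losses. It assumes the common WGAN convention of labelling real images 1 and generated images -1 and taking the loss as the negative mean of label times critic score. The function names `wasserstein_loss` and `binary_crossentropy` are illustrative and are not taken from the `models.WGAN` source.

```python
import numpy as np

def wasserstein_loss(y_true, y_pred):
    """Negative mean of label * critic score (labels assumed to be +1 / -1)."""
    return -np.mean(y_true * y_pred)

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Standard GAN loss: labels in {0, 1}, predictions are probabilities."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

# Critic scores are unbounded real numbers, not probabilities.
real_scores = np.array([0.8, 1.2, 0.5])
fake_scores = np.array([-0.3, -0.9, 0.1])

# The critic lowers its loss by scoring real images high and generated images low.
loss_real = wasserstein_loss(np.ones(3), real_scores)
loss_fake = wasserstein_loss(-np.ones(3), fake_scores)
print(loss_real, loss_fake)
```

Because the critic output is not squashed into [0, 1], this setup has no sigmoid on the last layer, which is consistent with the plain `Dense` output in the critic summary further down.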
import os
import numpy as np
import matplotlib.pyplot as plt
from models.WGAN import WGAN
from utils.loaders import load_cifar
# run params
SECTION = 'gan'
RUN_ID = '0002'
DATA_NAME = 'horses'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])
if not os.path.exists(RUN_FOLDER):
    os.mkdir(RUN_FOLDER)
    os.mkdir(os.path.join(RUN_FOLDER, 'viz'))
    os.mkdir(os.path.join(RUN_FOLDER, 'images'))
    os.mkdir(os.path.join(RUN_FOLDER, 'weights'))
mode = 'build'  # or 'load'
# Load the data
if DATA_NAME == 'cars':
    label = 1
elif DATA_NAME == 'horses':
    label = 7
(x_train, y_train) = load_cifar(label, 10)
plt.imshow((x_train[150,:,:,:]+1)/2)
# Build the model
# The WGAN object is constructed in both modes so that `gan` exists before load_weights is called
gan = WGAN(input_dim = (32,32,3)
    , critic_conv_filters = [32,64,128,128]
    , critic_conv_kernel_size = [5,5,5,5]
    , critic_conv_strides = [2,2,2,1]
    , critic_batch_norm_momentum = None
    , critic_activation = 'leaky_relu'
    , critic_dropout_rate = None
    , critic_learning_rate = 0.00005
    , generator_initial_dense_layer_size = (4, 4, 128)
    , generator_upsample = [2,2,2,1]
    , generator_conv_filters = [128,64,32,3]
    , generator_conv_kernel_size = [5,5,5,5]
    , generator_conv_strides = [1,1,1,1]
    , generator_batch_norm_momentum = 0.8
    , generator_activation = 'leaky_relu'
    , generator_dropout_rate = None
    , generator_learning_rate = 0.00005
    , optimiser = 'rmsprop'
    , z_dim = 100
    )
if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
gan.critic.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
critic_input (InputLayer) (None, 32, 32, 3) 0
_________________________________________________________________
critic_conv_0 (Conv2D) (None, 16, 16, 32) 2432
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 16, 16, 32) 0
_________________________________________________________________
critic_conv_1 (Conv2D) (None, 8, 8, 64) 51264
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 8, 8, 64) 0
_________________________________________________________________
critic_conv_2 (Conv2D) (None, 4, 4, 128) 204928
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 4, 4, 128) 0
_________________________________________________________________
critic_conv_3 (Conv2D) (None, 4, 4, 128) 409728
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 4, 4, 128) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 2048) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 2049
=================================================================
Total params: 670,401
Trainable params: 670,401
Non-trainable params: 0
_________________________________________________________________
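The parameter counts in the critic summary can be checked by hand. The sketch below uses the standard formulas for a convolution with bias (kernel x kernel x input channels x filters, plus one bias per filter) and for a dense layer. The helper names are made up for this check.

```python
def conv2d_params(kernel, in_channels, out_channels):
    """Weights (kernel * kernel * in * out) plus one bias per output channel."""
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

filters = [32, 64, 128, 128]   # critic_conv_filters
in_channels = 3                # RGB input
counts = []
for f in filters:
    counts.append(conv2d_params(5, in_channels, f))  # 5x5 kernels throughout
    in_channels = f
counts.append(dense_params(4 * 4 * 128, 1))          # flatten (2048) -> 1 score

critic_total = sum(counts)
print(counts, critic_total)
# prints [2432, 51264, 204928, 409728, 2049] 670401
```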
gan.generator.summary()
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
generator_input (InputLayer) (None, 100) 0
_________________________________________________________________
dense_2 (Dense) (None, 2048) 206848
_________________________________________________________________
batch_normalization_1 (Batch (None, 2048) 8192
_________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 2048) 0
_________________________________________________________________
reshape_1 (Reshape) (None, 4, 4, 128) 0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 8, 8, 128) 0
_________________________________________________________________
generator_conv_0 (Conv2D) (None, 8, 8, 128) 409728
_________________________________________________________________
batch_normalization_2 (Batch (None, 8, 8, 128) 512
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 8, 8, 128) 0
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 16, 16, 128) 0
_________________________________________________________________
generator_conv_1 (Conv2D) (None, 16, 16, 64) 204864
_________________________________________________________________
batch_normalization_3 (Batch (None, 16, 16, 64) 256
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 16, 16, 64) 0
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 32, 32, 64) 0
_________________________________________________________________
generator_conv_2 (Conv2D) (None, 32, 32, 32) 51232
_________________________________________________________________
batch_normalization_4 (Batch (None, 32, 32, 32) 128
_________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 32, 32, 32) 0
_________________________________________________________________
generator_conv_3 (Conv2DTran (None, 32, 32, 3) 2403
_________________________________________________________________
activation_1 (Activation) (None, 32, 32, 3) 0
=================================================================
Total params: 884,163
Trainable params: 879,619
Non-trainable params: 4,544
_________________________________________________________________
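The 4,544 non-trainable parameters come from the four batch normalization layers. Each one keeps a moving mean and a moving variance per channel (not trained by gradient descent) next to the trainable gamma and beta. A small check, with a helper name made up for illustration:

```python
def batch_norm_params(channels):
    """Return (trainable, non_trainable) counts for one batch normalization layer."""
    trainable = 2 * channels      # gamma and beta
    non_trainable = 2 * channels  # moving mean and moving variance
    return trainable, non_trainable

bn_channels = [2048, 128, 64, 32]  # sizes of the four batch_normalization layers above
non_trainable = sum(batch_norm_params(ch)[1] for ch in bn_channels)
total = 884163                     # "Total params" from the summary
print(non_trainable, total - non_trainable)
# prints 4544 879619
```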
# Train the model
BATCH_SIZE = 128
EPOCHS = 6000
PRINT_EVERY_N_BATCHES = 5
N_CRITIC = 5
CLIP_THRESHOLD = 0.01
gan.train(
    x_train
    , batch_size = BATCH_SIZE
    , epochs = EPOCHS
    , run_folder = RUN_FOLDER
    , print_every_n_batches = PRINT_EVERY_N_BATCHES
    , n_critic = N_CRITIC
    , clip_threshold = CLIP_THRESHOLD
)
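What `n_critic` and `clip_threshold` control can be sketched roughly as follows. This is not the body of `gan.train`, which is not shown here. It assumes the usual WGAN schedule: `n_critic` critic updates per generator update, with the critic weights clipped to [-clip_threshold, clip_threshold] after each critic update. The gradient steps are left out and the weights are random stand-ins.

```python
import numpy as np

def clip_weights(weights, clip_threshold):
    """Clip every weight array to the range [-clip_threshold, clip_threshold]."""
    return [np.clip(w, -clip_threshold, clip_threshold) for w in weights]

# Random stand-ins for one critic layer (kernel and bias), not real model weights.
rng = np.random.default_rng(0)
critic_weights = [rng.normal(0, 0.05, size=(5, 5, 3, 32)),
                  rng.normal(0, 0.05, size=(32,))]

N_CRITIC = 5
CLIP_THRESHOLD = 0.01
critic_updates = 0
for epoch in range(3):
    for _ in range(N_CRITIC):
        # (a critic gradient step on a real batch and a fake batch would go here)
        critic_weights = clip_weights(critic_weights, CLIP_THRESHOLD)
        critic_updates += 1
    # (a generator gradient step would go here)

largest = max(np.abs(w).max() for w in critic_weights)
print(critic_updates, largest)
```

Under this assumption, EPOCHS = 6000 with N_CRITIC = 5 means 30,000 critic updates for 6,000 generator updates.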
0 [D loss: (-0.000)(R -0.001, F 0.000)] [G loss: -0.000]
1 [D loss: (0.000)(R -0.001, F 0.001)] [G loss: -0.000]
2 [D loss: (-0.000)(R -0.001, F 0.001)] [G loss: -0.000]
3 [D loss: (-0.000)(R -0.001, F 0.001)] [G loss: -0.001]
4 [D loss: (-0.000)(R -0.002, F 0.001)] [G loss: -0.001]
5 [D loss: (-0.000)(R -0.002, F 0.001)] [G loss: -0.001]
6 [D loss: (-0.001)(R -0.003, F 0.002)] [G loss: -0.001]
7 [D loss: (-0.001)(R -0.004, F 0.003)] [G loss: -0.001]
8 [D loss: (-0.000)(R -0.003, F 0.002)] [G loss: -0.001]
9 [D loss: (-0.000)(R -0.004, F 0.003)] [G loss: -0.002]
10 [D loss: (-0.001)(R -0.005, F 0.002)] [G loss: -0.003]
.
.
.
5990 [D loss: (0.003)(R -0.058, F 0.064)] [G loss: -0.048]
5991 [D loss: (-0.004)(R -0.070, F 0.061)] [G loss: -0.052]
5992 [D loss: (-0.002)(R -0.068, F 0.063)] [G loss: -0.049]
5993 [D loss: (-0.000)(R -0.055, F 0.055)] [G loss: -0.042]
5994 [D loss: (-0.011)(R -0.067, F 0.045)] [G loss: -0.036]
5995 [D loss: (-0.008)(R -0.065, F 0.049)] [G loss: -0.045]
5996 [D loss: (-0.012)(R -0.068, F 0.043)] [G loss: -0.039]
5997 [D loss: (-0.012)(R -0.070, F 0.045)] [G loss: -0.038]
5998 [D loss: (-0.008)(R -0.060, F 0.045)] [G loss: -0.037]
5999 [D loss: (-0.011)(R -0.070, F 0.049)] [G loss: -0.043]
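Reading the log: in the printed lines the D loss looks like the average of the R and F values (presumably the critic loss on the real and the generated batch). That relation is inferred from the numbers above, not from the training code. A small parser to check it on a few lines; the regular expression and the function name are written just for this check:

```python
import re

LOG_PATTERN = re.compile(
    r"(\d+) \[D loss: \((-?\d+\.\d+)\)\(R (-?\d+\.\d+), F (-?\d+\.\d+)\)\] "
    r"\[G loss: (-?\d+\.\d+)\]"
)

def parse_log_line(line):
    """Return (step, d_loss, d_real, d_fake, g_loss) from one training log line."""
    match = LOG_PATTERN.match(line)
    if match is None:
        raise ValueError("unrecognised log line: " + line)
    step = int(match.group(1))
    d_loss, d_real, d_fake, g_loss = (float(v) for v in match.groups()[1:])
    return step, d_loss, d_real, d_fake, g_loss

lines = [
    "5990 [D loss: (0.003)(R -0.058, F 0.064)] [G loss: -0.048]",
    "5994 [D loss: (-0.011)(R -0.067, F 0.045)] [G loss: -0.036]",
    "5999 [D loss: (-0.011)(R -0.070, F 0.049)] [G loss: -0.043]",
]
for line in lines:
    step, d_loss, d_real, d_fake, g_loss = parse_log_line(line)
    # Logged values are rounded to 3 decimals, so compare with a small tolerance.
    matches = abs(d_loss - 0.5 * (d_real + d_fake)) <= 0.0015
    print(step, d_loss, matches)
```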
gan.sample_images(RUN_FOLDER)
fig = plt.figure()
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)
plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)
plt.plot(gan.g_losses, color='orange', linewidth=0.25)
plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)
# plt.xlim(0, 2000)
# plt.ylim(0, 2)
plt.show()
def compare_images(img1, img2):
    return np.mean(np.abs(img1 - img2))
r, c = 5, 5
idx = np.random.randint(0, x_train.shape[0], BATCH_SIZE)
true_imgs = (x_train[idx] + 1) *0.5
fig, axs = plt.subplots(r, c, figsize=(15,15))
cnt = 0
for i in range(r):
    for j in range(c):
        axs[i,j].imshow(true_imgs[cnt], cmap = 'gray_r')
        axs[i,j].axis('off')
        cnt += 1
fig.savefig(os.path.join(RUN_FOLDER, "images/real.png"))
plt.show()
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, gan.z_dim))
gen_imgs = gan.generator.predict(noise)
#Rescale images 0 - 1
gen_imgs = 0.5 * (gen_imgs + 1)
# gen_imgs = np.clip(gen_imgs, 0, 1)
fig, axs = plt.subplots(r, c, figsize=(15,15))
cnt = 0
for i in range(r):
    for j in range(c):
        axs[i,j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]), cmap = 'gray_r')
        axs[i,j].axis('off')
        cnt += 1
fig.savefig(os.path.join(RUN_FOLDER, "images/sample.png"))
plt.close()
fig, axs = plt.subplots(r, c, figsize=(15,15))
cnt = 0
for i in range(r):
    for j in range(c):
        c_diff = 99999
        c_img = None
        for k_idx, k in enumerate((x_train + 1) * 0.5):
            diff = compare_images(gen_imgs[cnt, :,:,:], k)
            if diff < c_diff:
                c_img = np.copy(k)
                c_diff = diff
        axs[i,j].imshow(c_img, cmap = 'gray_r')
        axs[i,j].axis('off')
        cnt += 1
fig.savefig(os.path.join(RUN_FOLDER, "images/sample_closest.png"))
plt.show()
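The closest-image search above compares each generated image with every training image in a Python loop, which is slow. A possible vectorized variant using the same mean absolute difference is sketched below. The function name `closest_training_images` is made up, and the arrays are random stand-ins for `gen_imgs` and `(x_train + 1) * 0.5`.

```python
import numpy as np

def closest_training_images(gen_imgs, train_imgs):
    """For each generated image, return the index of the training image
    with the smallest mean absolute pixel difference."""
    gen_flat = gen_imgs.reshape(len(gen_imgs), -1)
    train_flat = train_imgs.reshape(len(train_imgs), -1)
    indices = []
    for g in gen_flat:  # one generated image at a time keeps memory use modest
        dists = np.mean(np.abs(train_flat - g), axis=1)
        indices.append(int(np.argmin(dists)))
    return np.array(indices)

# Stand-in data: 50 random "training" images and 3 slightly shifted copies.
rng = np.random.default_rng(0)
train = rng.random((50, 32, 32, 3))
gen = train[[3, 17, 42]] + 0.001
closest = closest_training_images(gen, train)
print(closest)
# prints [ 3 17 42]
```

With the real data the returned indices can be used directly, for example `((x_train + 1) * 0.5)[closest[cnt]]` in place of `c_img`.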
https://broscoding.tistory.com/308