반응형
Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
 | | | | | 1 | 2 |
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- web 사진
- cudnn
- 대이터
- web 용어
- classification
- 머신러닝
- java역사
- vscode
- tensorflow
- web 개발
- bccard
- pycharm
- web
- 데이터전문기관
- 자료구조
- inorder
- C언어
- paragraph
- CES 2O21 참여
- 결합전문기관
- broscoding
- discrete_scatter
- Keras
- html
- CES 2O21 참가
- 재귀함수
- mglearn
- 웹 용어
- KNeighborsClassifier
- postorder
Archives
- Today
- Total
bro's coding
WGAN.source 본문
반응형
from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout, ZeroPadding2D, UpSampling2D
from keras.layers.merge import _Merge
from keras.models import Model, Sequential
from keras import backend as K
from keras.optimizers import Adam, RMSprop
from keras.callbacks import ModelCheckpoint
from keras.utils import plot_model
from keras.initializers import RandomNormal
import numpy as np
import json
import os
import pickle
import matplotlib.pyplot as plt
class WGAN():
    """Wasserstein GAN with convolutional critic/generator, trained with weight clipping.

    Constructing an instance builds three Keras models:
      * ``self.critic``    -- image of shape ``input_dim`` -> single unbounded score
                              (final Dense(1) has no activation),
      * ``self.generator`` -- latent vector of length ``z_dim`` -> image (tanh output),
      * ``self.model``     -- generator followed by the critic (frozen when compiled),
                              used to train the generator.

    The per-layer list arguments (``critic_conv_*``, ``generator_conv_*``,
    ``generator_upsample``) are indexed layer by layer, so they are expected to be
    at least as long as the matching ``*_conv_filters`` list.
    """

    def __init__(self
            , input_dim
            , critic_conv_filters
            , critic_conv_kernel_size
            , critic_conv_strides
            , critic_batch_norm_momentum
            , critic_activation
            , critic_dropout_rate
            , critic_learning_rate
            , generator_initial_dense_layer_size
            , generator_upsample
            , generator_conv_filters
            , generator_conv_kernel_size
            , generator_conv_strides
            , generator_batch_norm_momentum
            , generator_activation
            , generator_dropout_rate
            , generator_learning_rate
            , optimiser
            , z_dim
            ):
        """Store the hyper-parameters and build/compile the critic, generator and combined model.

        Args:
            input_dim: shape passed to the critic's ``Input`` layer (image shape).
            critic_conv_filters / _kernel_size / _strides: per-layer Conv2D settings.
            critic_batch_norm_momentum: BatchNormalization momentum for critic layers
                after the first; falsy disables batch norm.
            critic_activation: ``'leaky_relu'`` or any Keras activation name.
            critic_dropout_rate: dropout after each critic block; falsy disables it.
            critic_learning_rate: learning rate for the critic optimiser.
            generator_initial_dense_layer_size: shape the first Dense output is reshaped to.
            generator_upsample: per layer, ``2`` selects UpSampling2D + Conv2D,
                anything else selects a strided Conv2DTranspose.
            generator_conv_filters / _kernel_size / _strides: per-layer settings.
            generator_batch_norm_momentum: BatchNormalization momentum; falsy disables it.
            generator_activation: ``'leaky_relu'`` or any Keras activation name.
            generator_dropout_rate: dropout after the initial reshape; falsy disables it.
            generator_learning_rate: learning rate for the combined-model optimiser.
            optimiser: ``'adam'``, ``'rmsprop'``, anything else falls back to default Adam.
            z_dim: length of the latent noise vector.
        """
        self.name = 'gan'
        self.input_dim = input_dim
        self.critic_conv_filters = critic_conv_filters
        self.critic_conv_kernel_size = critic_conv_kernel_size
        self.critic_conv_strides = critic_conv_strides
        self.critic_batch_norm_momentum = critic_batch_norm_momentum
        self.critic_activation = critic_activation
        self.critic_dropout_rate = critic_dropout_rate
        self.critic_learning_rate = critic_learning_rate
        self.generator_initial_dense_layer_size = generator_initial_dense_layer_size
        self.generator_upsample = generator_upsample
        self.generator_conv_filters = generator_conv_filters
        self.generator_conv_kernel_size = generator_conv_kernel_size
        self.generator_conv_strides = generator_conv_strides
        self.generator_batch_norm_momentum = generator_batch_norm_momentum
        self.generator_activation = generator_activation
        self.generator_dropout_rate = generator_dropout_rate
        self.generator_learning_rate = generator_learning_rate
        self.optimiser = optimiser
        self.z_dim = z_dim
        self.n_layers_critic = len(critic_conv_filters)
        self.n_layers_generator = len(generator_conv_filters)
        self.weight_init = RandomNormal(mean=0., stddev=0.02)
        # Per-iteration loss history: d_losses holds [total, real, fake] lists.
        self.d_losses = []
        self.g_losses = []
        # Number of training iterations completed so far; train() resumes from here.
        self.epoch = 0
        self._build_critic()
        self._build_generator()
        self._build_adversarial()

    def wasserstein(self, y_true, y_pred):
        """Wasserstein loss: ``-mean(y_true * y_pred)``.

        ``train_critic`` labels real images +1 and generated images -1, so
        minimising this pushes critic scores up for real and down for fake.
        """
        return - K.mean(y_true * y_pred)

    def get_activation(self, activation):
        """Return a fresh activation layer: LeakyReLU(0.2) for 'leaky_relu', else Activation(name)."""
        if activation == 'leaky_relu':
            layer = LeakyReLU(alpha = 0.2)
        else:
            layer = Activation(activation)
        return layer

    def _build_critic(self):
        """Build ``self.critic``: a stack of Conv2D blocks, then Flatten and a linear Dense(1)."""
        critic_input = Input(shape=self.input_dim, name='critic_input')
        x = critic_input

        for i in range(self.n_layers_critic):
            x = Conv2D(
                filters = self.critic_conv_filters[i]
                , kernel_size = self.critic_conv_kernel_size[i]
                , strides = self.critic_conv_strides[i]
                , padding = 'same'
                , name = 'critic_conv_' + str(i)
                , kernel_initializer = self.weight_init
                )(x)

            # Batch norm is skipped on the first block.
            if self.critic_batch_norm_momentum and i > 0:
                x = BatchNormalization(momentum = self.critic_batch_norm_momentum)(x)

            x = self.get_activation(self.critic_activation)(x)

            if self.critic_dropout_rate:
                x = Dropout(rate = self.critic_dropout_rate)(x)

        x = Flatten()(x)

        # No output activation: the WGAN critic produces an unbounded score, not a probability.
        critic_output = Dense(1, activation=None
            , kernel_initializer = self.weight_init
            )(x)

        self.critic = Model(critic_input, critic_output)

    def _build_generator(self):
        """Build ``self.generator``: Dense -> Reshape -> upsampling conv stack -> tanh."""
        generator_input = Input(shape=(self.z_dim,), name='generator_input')
        x = generator_input

        x = Dense(np.prod(self.generator_initial_dense_layer_size)
            , kernel_initializer = self.weight_init
            )(x)

        if self.generator_batch_norm_momentum:
            x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)

        x = self.get_activation(self.generator_activation)(x)

        x = Reshape(self.generator_initial_dense_layer_size)(x)

        if self.generator_dropout_rate:
            x = Dropout(rate = self.generator_dropout_rate)(x)

        for i in range(self.n_layers_generator):

            if self.generator_upsample[i] == 2:
                # Nearest-neighbour upsample followed by a stride-1 convolution.
                x = UpSampling2D()(x)
                x = Conv2D(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)
            else:
                x = Conv2DTranspose(
                    filters = self.generator_conv_filters[i]
                    , kernel_size = self.generator_conv_kernel_size[i]
                    , padding = 'same'
                    , strides = self.generator_conv_strides[i]
                    , name = 'generator_conv_' + str(i)
                    , kernel_initializer = self.weight_init
                    )(x)

            if i < self.n_layers_generator - 1:
                if self.generator_batch_norm_momentum:
                    x = BatchNormalization(momentum = self.generator_batch_norm_momentum)(x)
                x = self.get_activation(self.generator_activation)(x)
            else:
                # Last layer: tanh puts pixel values in [-1, 1] (sample_images rescales to [0, 1]).
                x = Activation('tanh')(x)

        generator_output = x

        self.generator = Model(generator_input, generator_output)

    def get_opti(self, lr):
        """Return the optimiser selected by ``self.optimiser`` with learning rate ``lr``."""
        if self.optimiser == 'adam':
            opti = Adam(lr=lr, beta_1=0.5)
        elif self.optimiser == 'rmsprop':
            opti = RMSprop(lr=lr)
        else:
            opti = Adam(lr=lr)

        return opti

    def set_trainable(self, m, val):
        """Set the ``trainable`` flag on model ``m`` and on every one of its layers."""
        m.trainable = val
        for l in m.layers:
            l.trainable = val

    def _build_adversarial(self):
        """Compile the critic, then build and compile the generator->critic model."""
        ### COMPILE critic
        self.critic.compile(
            optimizer=self.get_opti(self.critic_learning_rate)
            , loss = self.wasserstein
        )

        ### COMPILE THE FULL GAN
        # Freeze the critic only while the combined model is compiled, so that
        # training self.model updates the generator weights alone; the critic,
        # compiled above while trainable, keeps its own update path.
        self.set_trainable(self.critic, False)

        model_input = Input(shape=(self.z_dim,), name='model_input')
        model_output = self.critic(self.generator(model_input))
        self.model = Model(model_input, model_output)

        self.model.compile(
            optimizer=self.get_opti(self.generator_learning_rate)
            , loss=self.wasserstein
        )

        self.set_trainable(self.critic, True)

    def train_critic(self, x_train, batch_size, clip_threshold, using_generator):
        """Run one critic update on a real batch and a generated batch, then clip weights.

        Args:
            x_train: an array of images, or (if ``using_generator``) an iterator whose
                ``next()`` yields tuples with the image batch at index 0.
            batch_size: number of real and of generated images per update.
            clip_threshold: trainable critic weights are clipped to
                ``[-clip_threshold, clip_threshold]``.
            using_generator: whether ``x_train`` is an iterator rather than an array.

        Returns:
            ``[d_loss, d_loss_real, d_loss_fake]`` where ``d_loss`` is their mean.
        """
        valid = np.ones((batch_size,1))
        fake = -np.ones((batch_size,1))

        if using_generator:
            true_imgs = next(x_train)[0]
            # A short batch (e.g. the end of an epoch) is discarded once so the
            # labels above still match the batch size.
            if true_imgs.shape[0] != batch_size:
                true_imgs = next(x_train)[0]
        else:
            idx = np.random.randint(0, x_train.shape[0], batch_size)
            true_imgs = x_train[idx]

        noise = np.random.normal(0, 1, (batch_size, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        d_loss_real = self.critic.train_on_batch(true_imgs, valid)
        d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * (d_loss_real + d_loss_fake)

        # Weight clipping (WGAN Lipschitz constraint). Keras orders a layer's
        # get_weights() as its trainable weights first, then non-trainable ones.
        # Only the trainable parameters are clipped: non-trainable arrays such as
        # BatchNormalization moving mean/variance are running statistics, and
        # clipping them would corrupt the normalisation used at inference.
        for l in self.critic.layers:
            weights = l.get_weights()
            n_trainable = len(l.trainable_weights)
            clipped = [np.clip(w, -clip_threshold, clip_threshold) for w in weights[:n_trainable]]
            l.set_weights(clipped + weights[n_trainable:])

        return [d_loss, d_loss_real, d_loss_fake]

    def train_generator(self, batch_size):
        """Run one generator update through the frozen critic; return the loss."""
        # Generated images are labelled "real" (+1) so the generator is pushed
        # towards samples the critic scores as real.
        valid = np.ones((batch_size,1))
        noise = np.random.normal(0, 1, (batch_size, self.z_dim))
        return self.model.train_on_batch(noise, valid)

    def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches = 10
        , n_critic = 5
        , clip_threshold = 0.01
        , using_generator = False):
        """Train for ``epochs`` iterations, each = ``n_critic`` critic updates + 1 generator update.

        Note that an "epoch" here is a single training iteration, not a pass over the data.
        Every ``print_every_n_batches`` iterations a sample grid and a numbered weights
        file are written; ``weights/weights.h5`` and the saved models are overwritten
        every iteration. ``run_folder`` is expected to contain ``weights/`` and
        ``images/`` subdirectories.
        """
        for epoch in range(self.epoch, self.epoch + epochs):

            for _ in range(n_critic):
                d_loss = self.train_critic(x_train, batch_size, clip_threshold, using_generator)

            g_loss = self.train_generator(batch_size)

            # Plot the progress
            print ("%d [D loss: (%.3f)(R %.3f, F %.3f)] [G loss: %.3f] " % (epoch, d_loss[0], d_loss[1], d_loss[2], g_loss))

            self.d_losses.append(d_loss)
            self.g_losses.append(g_loss)

            # If at save interval => save generated image samples
            if epoch % print_every_n_batches == 0:
                self.sample_images(run_folder)
                self.model.save_weights(os.path.join(run_folder, 'weights/weights-%d.h5' % (epoch)))
            self.model.save_weights(os.path.join(run_folder, 'weights/weights.h5'))
            self.save_model(run_folder)

            self.epoch+=1

    def sample_images(self, run_folder):
        """Save a 5x5 grid of generated images to ``images/sample_<epoch>.png`` in ``run_folder``."""
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.z_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale the generator's tanh output from [-1, 1] to [0, 1].
        gen_imgs = 0.5 * (gen_imgs + 1)
        gen_imgs = np.clip(gen_imgs, 0, 1)

        fig, axs = plt.subplots(r, c, figsize=(15,15))
        cnt = 0

        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(np.squeeze(gen_imgs[cnt, :,:,:]), cmap = 'gray_r')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(run_folder, "images/sample_%d.png" % self.epoch))
        # Close this specific figure rather than whichever one pyplot considers current.
        plt.close(fig)

    def plot_model(self, run_folder):
        """Write architecture diagrams of the three models into ``run_folder/viz/``."""
        # Inside a method, the bare name `plot_model` resolves to the module-level
        # keras.utils.plot_model function, not to this method.
        plot_model(self.model, to_file=os.path.join(run_folder ,'viz/model.png'), show_shapes = True, show_layer_names = True)
        plot_model(self.critic, to_file=os.path.join(run_folder ,'viz/critic.png'), show_shapes = True, show_layer_names = True)
        plot_model(self.generator, to_file=os.path.join(run_folder ,'viz/generator.png'), show_shapes = True, show_layer_names = True)

    def save(self, folder):
        """Pickle the constructor arguments (in constructor order) to ``params.pkl`` and plot the models."""
        with open(os.path.join(folder, 'params.pkl'), 'wb') as f:
            pickle.dump([
                self.input_dim
                , self.critic_conv_filters
                , self.critic_conv_kernel_size
                , self.critic_conv_strides
                , self.critic_batch_norm_momentum
                , self.critic_activation
                , self.critic_dropout_rate
                , self.critic_learning_rate
                , self.generator_initial_dense_layer_size
                , self.generator_upsample
                , self.generator_conv_filters
                , self.generator_conv_kernel_size
                , self.generator_conv_strides
                , self.generator_batch_norm_momentum
                , self.generator_activation
                , self.generator_dropout_rate
                , self.generator_learning_rate
                , self.optimiser
                , self.z_dim
                ], f)

        self.plot_model(folder)

    def save_model(self, run_folder):
        """Save the three Keras models as .h5 files and pickle this object to ``obj.pkl``."""
        self.model.save(os.path.join(run_folder, 'model.h5'))
        self.critic.save(os.path.join(run_folder, 'critic.h5'))
        self.generator.save(os.path.join(run_folder, 'generator.h5'))
        # Context manager guarantees the pickle file is flushed and closed
        # (the previous bare open() leaked the handle).
        with open(os.path.join(run_folder, "obj.pkl"), "wb") as f:
            pickle.dump(self, f)

    def load_weights(self, filepath):
        """Load weights into the combined model (which contains the generator and critic layers)."""
        self.model.load_weights(filepath)
https://broscoding.tistory.com/308
반응형
'[AI] > GAN' 카테고리의 다른 글
WGANGP.train.faces (0) | 2020.07.15 |
---|---|
Threshold(임계치) (0) | 2020.07.15 |
GAN VS DCGAN (0) | 2020.07.15 |
WGAN.weight clipping(가중치 클리핑) (0) | 2020.07.15 |
WGAN.mode collapse(모드 붕괴) (0) | 2020.07.15 |
GAN.source (0) | 2020.07.15 |
WGAN.example.train.cifar (0) | 2020.07.14 |
GAN.train.camel (0) | 2020.07.14 |
Comments