diff --git a/train.py b/train.py index 1a9e074..e2f5443 100755 --- a/train.py +++ b/train.py @@ -1,146 +1,101 @@ #!/usr/bin/env python3 +import sys +import pickle + +import matplotlib.pyplot as plt import numpy as np -from keras import Model -from keras.datasets import mnist from keras.callbacks import TensorBoard +from keras.optimizers import RMSprop -from vaegan.models import create_models - -import cv2 - - -def main(): - encoder, decoder, discriminator, vae, vae_loss = create_models() - # - # encoder.compile('rmsprop', 'mse') - # - # x = np.random.uniform(-1.0, 1.0, size=[1, 64, 64, 1]) - # y1 = np.random.uniform(-1.0, 1.0, size=[1, 128]) - # y2 = np.random.uniform(-1.0, 1.0, size=[1, 128]) - # - # encoder.fit(x, [y1, y2], callbacks=[TensorBoard()]) - # - # return - - - - batch_size = 32 +from vaegan.models import create_models, build_graph +from vaegan.training import fit_models +from vaegan.data import celeba_loader, encoder_loader, decoder_loader, discriminator_loader, NUM_SAMPLES, mnist_loader +from vaegan.callbacks import DecoderSnapshot, ModelsCheckpoint - (x_train, y_train), (x_test, y_test) = mnist.load_data() - # Resize to 64x64 - x_train_new = np.zeros((x_train.shape[0], 64, 64), dtype='int32') - for i, img in enumerate(x_train): - x_train_new[i] = cv2.resize(img, (64, 64), interpolation=cv2.INTER_CUBIC) +def set_trainable(model, trainable): + model.trainable = trainable + for layer in model.layers: + layer.trainable = trainable - x_train = x_train_new - del x_train_new - - # Normalize to [-1, 1] - #x_train = np.pad(x_train, ((0, 0), (18, 18), (18, 18)), mode='constant', constant_values=0) - x_train = np.expand_dims(x_train, -1) - x_train = (x_train.astype('float32') - 127.5) / 127.5 - x_train = np.clip(x_train, -1., 1.) - - - # Assume images in x_train - # x_train = np.zeros((100, 64, 64, 3)) - - discriminator.compile('rmsprop', 'binary_crossentropy', ['accuracy']) - discriminator.trainable = False - - model = Model(vae.inputs, discriminator(vae.outputs), name='vaegan') - model.add_loss(vae_loss) - model.compile('rmsprop', 'binary_crossentropy', ['accuracy']) - - import keras.callbacks as cbks - import os.path - - verbose = True - checkpoint = cbks.ModelCheckpoint(os.path.join('.', 'model.{epoch:02d}.h5'), save_weights_only=True) +def main(): + encoder, decoder, discriminator = create_models() + encoder_train, decoder_train, discriminator_train, vae, vaegan = build_graph(encoder, decoder, discriminator) - callbacks = [TensorBoard(batch_size=batch_size), checkpoint] + try: + initial_epoch = int(sys.argv[1]) + except (IndexError, ValueError): + initial_epoch = 0 - epochs = 100 - steps_per_epoch = x_train.shape[0] // batch_size - do_validation = False + epoch_format = '.{epoch:03d}.h5' - callback_metrics = ['disc_loss', 'disc_accuracy', 'vaegan_loss', 'vaegan_accuracy'] + if initial_epoch != 0: + suffix = epoch_format.format(epoch=initial_epoch) + encoder.load_weights('encoder' + suffix) + decoder.load_weights('decoder' + suffix) + discriminator.load_weights('discriminator' + suffix) - model.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [model.history] - if verbose: - callbacks += [cbks.ProgbarLogger(count_mode='steps')] - callbacks = cbks.CallbackList(callbacks) + batch_size = 64 + rmsprop = RMSprop(lr=0.0003) - # it's possible to callback a different model than self: - if hasattr(model, 'callback_model') and model.callback_model: - callback_model = model.callback_model - else: - callback_model = model - callbacks.set_model(callback_model) - callbacks.set_params({ - 'epochs': epochs, - 'steps': steps_per_epoch, - 'verbose': verbose, - 'do_validation': do_validation, - 'metrics': callback_metrics, - }) - callbacks.on_train_begin() + set_trainable(encoder, False) + set_trainable(decoder, False) + discriminator_train.compile(rmsprop, ['binary_crossentropy'] * 3, ['acc'] * 3) + discriminator_train.summary() - epoch_logs = {} + set_trainable(discriminator, False) + set_trainable(decoder, True) + decoder_train.compile(rmsprop, ['binary_crossentropy'] * 2, ['acc'] * 2) + decoder_train.summary() - for epoch in range(epochs): + set_trainable(decoder, False) + set_trainable(encoder, True) + encoder_train.compile(rmsprop) + encoder_train.summary() - callbacks.on_epoch_begin(epoch) + set_trainable(vaegan, True) - for batch_index in range(steps_per_epoch): - batch_logs = {} - batch_logs['batch'] = batch_index - batch_logs['size'] = batch_size - callbacks.on_batch_begin(batch_index, batch_logs) + checkpoint = ModelsCheckpoint(epoch_format, encoder, decoder, discriminator) + decoder_sampler = DecoderSnapshot() + callbacks = [checkpoint, decoder_sampler, TensorBoard()] - rand_indexes = np.random.randint(0, x_train.shape[0], size=batch_size) - real_images = x_train[rand_indexes] + epochs = 250 - fake_images = vae.predict(real_images) - # print(fake_images.shape) - half_batch = batch_size // 2 - inputs = np.concatenate([real_images[:half_batch], fake_images[:half_batch]]) + steps_per_epoch = NUM_SAMPLES // batch_size - # Label real and fake images - y = np.ones([batch_size, 1], dtype='float32') - y[half_batch:, :] = 0 + seed = np.random.randint(2**32 - 1) - # Train the Discriminator network - metrics = discriminator.train_on_batch(inputs, y) - # print('discriminator', metrics) + img_loader = celeba_loader(batch_size, num_child=3, seed=seed) + dis_loader = discriminator_loader(img_loader, seed=seed) + dec_loader = decoder_loader(img_loader, seed=seed) + enc_loader = encoder_loader(img_loader) - y = np.ones([batch_size, 1], dtype='float32') - vg_metrics = model.train_on_batch(fake_images, y) - # print('full', metrics) + models = [discriminator_train, decoder_train, encoder_train] + generators = [dis_loader, dec_loader, enc_loader] + metrics = [{'di_l': 1, 'di_l_t': 2, 'di_l_p': 3, 'di_a': 4, 'di_a_t': 7, 'di_a_p': 10}, {'de_l_t': 1, 'de_l_p': 2, 'de_a_t': 3, 'de_a_p': 5}, {'en_l': 0}] - batch_logs['disc_loss'] = metrics[0] - batch_logs['disc_accuracy'] = metrics[1] - batch_logs['vaegan_loss'] = vg_metrics[0] - batch_logs['vaegan_accuracy'] = vg_metrics[1] + histories = fit_models(vaegan, models, generators, metrics, batch_size, + steps_per_epoch=steps_per_epoch, callbacks=callbacks, + epochs=epochs, initial_epoch=initial_epoch) - callbacks.on_batch_end(batch_index, batch_logs) + with open('histories.pickle', 'wb') as f: + pickle.dump(histories, f) - callbacks.on_epoch_end(epoch, epoch_logs) + x = next(celeba_loader(1)) - rand_indexes = np.random.randint(0, x_train.shape[0], size=1) - real_images = x_train[rand_indexes] + x_tilde = vae.predict(x) - model.save_weights('trained.h5') + plt.subplot(211) + plt.imshow((x[0].squeeze() + 1.) / 2.) - a = encoder.predict(real_images) - print(a) + plt.subplot(212) + plt.imshow((x_tilde[0].squeeze() + 1.) / 2.) + plt.show() if __name__ == '__main__': diff --git a/train_adagrad.py b/train_adagrad.py new file mode 100644 index 0000000..b31fd85 --- /dev/null +++ b/train_adagrad.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 + +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np + +from keras.callbacks import TensorBoard, ModelCheckpoint +from keras.optimizers import Adagrad + +from vaegan.models import create_models, build_graph +from vaegan.training import fit_models +from vaegan.data import celeba_loader, encoder_loader, decoder_loader, discriminator_loader, NUM_SAMPLES, mnist_loader +from vaegan.callbacks import DecoderSnapshot + + +def set_trainable(model, trainable): + model.trainable = trainable + for layer in model.layers: + layer.trainable = trainable + + +def main(): + encoder, decoder, discriminator = create_models() + encoder_train, decoder_train, discriminator_train, vae, vaegan = build_graph(encoder, decoder, discriminator) + + if len(sys.argv) == 3: + vaegan.load_weights(sys.argv[1]) + initial_epoch = int(sys.argv[2]) + else: + initial_epoch = 0 + + batch_size = 64 + + opt = Adagrad(lr=0.01, epsilon=None, decay=0.0) + + set_trainable(encoder, False) + set_trainable(decoder, False) + discriminator_train.compile(opt, ['binary_crossentropy'] * 3, ['acc'] * 3) + discriminator_train.summary() + + set_trainable(discriminator, False) + set_trainable(decoder, True) + decoder_train.compile(opt, ['binary_crossentropy'] * 2, ['acc'] * 2) + decoder_train.summary() + + set_trainable(decoder, False) + set_trainable(encoder, True) + encoder_train.compile(opt) + encoder_train.summary() + + set_trainable(vaegan, True) + + checkpoint = ModelCheckpoint(os.path.join('.', 'model.{epoch:02d}.h5'), save_weights_only=True) + decoder_sampler = DecoderSnapshot() + + callbacks = [checkpoint, decoder_sampler, TensorBoard()] + + epochs = 100 + + steps_per_epoch = NUM_SAMPLES // batch_size + + seed = np.random.randint(2**32 - 1) + + img_loader = celeba_loader(batch_size, num_child=3, seed=seed) + dis_loader = discriminator_loader(img_loader, seed=seed) + dec_loader = decoder_loader(img_loader, seed=seed) + enc_loader = encoder_loader(img_loader) + + models = [discriminator_train, decoder_train, encoder_train] + generators = [dis_loader, dec_loader, enc_loader] + metrics = [{'di_l': 1, 'di_l_t': 2, 'di_l_p': 3, 'di_a': 4, 'di_a_t': 7, 'di_a_p': 10}, {'de_l_t': 1, 'de_l_p': 2, 'de_a_t': 3, 'de_a_p': 5}, {'en_l': 0}] + + histories = fit_models(vaegan, models, generators, metrics, batch_size, + steps_per_epoch=steps_per_epoch, callbacks=callbacks, + epochs=epochs, initial_epoch=initial_epoch) + + vaegan.save_weights('trained.h5') + + x = next(celeba_loader(1)) + + x_tilde = vae.predict(x) + + plt.subplot(211) + plt.imshow((x[0].squeeze() + 1.) / 2.) + + plt.subplot(212) + plt.imshow((x_tilde[0].squeeze() + 1.) / 2.) + + plt.show() + + +if __name__ == '__main__': + main() diff --git a/train_dualgpu.py b/train_dualgpu.py new file mode 100644 index 0000000..e761a67 --- /dev/null +++ b/train_dualgpu.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 + +import os +import sys + +import matplotlib.pyplot as plt +import numpy as np + +from keras.callbacks import TensorBoard, ModelCheckpoint +from keras.optimizers import RMSprop + +from vaegan.models import create_models, build_graph +from vaegan.training import fit_models +from vaegan.data import celeba_loader, encoder_loader, decoder_loader, discriminator_loader, NUM_SAMPLES, mnist_loader +from vaegan.callbacks import DecoderSnapshot + +from keras.utils import multi_gpu_model + +class ModelMGPU(Model): + def __init__(self, ser_model, gpus): + pmodel = multi_gpu_model(ser_model, gpus) + self.__dict__.update(pmodel.__dict__) + self._smodel = ser_model + + def __getattribute__(self, attrname): + '''Override load and save methods to be used from the serial-model. The + serial-model holds references to the weights in the multi-gpu model. + ''' + # return Model.__getattribute__(self, attrname) + if 'load' in attrname or 'save' in attrname: + return getattr(self._smodel, attrname) + + return super(ModelMGPU, self).__getattribute__(attrname) + +def set_trainable(model, trainable): + model.trainable = trainable + for layer in model.layers: + layer.trainable = trainable + + +def main(): + model_s = Model(inputs=inputs, outputs=outputs) + model = ModelMGPU(model_s, gpus=2) #try implementation + encoder, decoder, discriminator = create_models() + encoder_train, decoder_train, discriminator_train, vae, vaegan = build_graph(encoder, decoder, discriminator) + + if len(sys.argv) == 3: + vaegan.load_weights(sys.argv[1]) + initial_epoch = int(sys.argv[2]) + else: + initial_epoch = 0 + + batch_size = 64 + + rmsprop = RMSprop(lr=0.0003) + + set_trainable(encoder, False) + set_trainable(decoder, False) + discriminator_train.compile(rmsprop, ['binary_crossentropy'] * 3, ['acc'] * 3) + discriminator_train.summary() + + set_trainable(discriminator, False) + set_trainable(decoder, True) + decoder_train.compile(rmsprop, ['binary_crossentropy'] * 2, ['acc'] * 2) + decoder_train.summary() + + set_trainable(decoder, False) + set_trainable(encoder, True) + encoder_train.compile(rmsprop) + encoder_train.summary() + + set_trainable(vaegan, True) + + checkpoint = ModelCheckpoint(os.path.join('.', 'model.{epoch:02d}.h5'), save_weights_only=True) + decoder_sampler = DecoderSnapshot() + + callbacks = [checkpoint, decoder_sampler, TensorBoard()] + + epochs = 100 + + steps_per_epoch = NUM_SAMPLES // batch_size + + seed = np.random.randint(2**32 - 1) + + img_loader = celeba_loader(batch_size, num_child=3, seed=seed) + dis_loader = discriminator_loader(img_loader, seed=seed) + dec_loader = decoder_loader(img_loader, seed=seed) + enc_loader = encoder_loader(img_loader) + + models = [discriminator_train, decoder_train, encoder_train] + generators = [dis_loader, dec_loader, enc_loader] + metrics = [{'di_l': 1, 'di_l_t': 2, 'di_l_p': 3, 'di_a': 4, 'di_a_t': 7, 'di_a_p': 10}, {'de_l_t': 1, 'de_l_p': 2, 'de_a_t': 3, 'de_a_p': 5}, {'en_l': 0}] + + histories = fit_models(vaegan, models, generators, metrics, batch_size, + steps_per_epoch=steps_per_epoch, callbacks=callbacks, + epochs=epochs, initial_epoch=initial_epoch) + + vaegan.save_weights('trained.h5') + + x = next(celeba_loader(1)) + + x_tilde = vae.predict(x) + + plt.subplot(211) + plt.imshow((x[0].squeeze() + 1.) / 2.) + + plt.subplot(212) + plt.imshow((x_tilde[0].squeeze() + 1.) / 2.) + + plt.show() + + +if __name__ == '__main__': + main() diff --git a/vaegan/callbacks.py b/vaegan/callbacks.py new file mode 100644 index 0000000..397bb42 --- /dev/null +++ b/vaegan/callbacks.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +from concurrent.futures import ThreadPoolExecutor + +import numpy as np +from PIL import Image + +from keras.callbacks import Callback + + +class DecoderSnapshot(Callback): + + def __init__(self, step_size=200, latent_dim=128, decoder_index=-2): + super().__init__() + self._step_size = step_size + self._steps = 0 + self._epoch = 0 + self._latent_dim = latent_dim + self._decoder_index = decoder_index + self._img_rows = 64 + self._img_cols = 64 + self._thread_pool = ThreadPoolExecutor(1) + + def on_epoch_begin(self, epoch, logs=None): + self._epoch = epoch + self._steps = 0 + + def on_batch_begin(self, batch, logs=None): + self._steps += 1 + if self._steps % self._step_size == 0: + self.plot_images() + + def plot_images(self, samples=16): + decoder = self.model.layers[self._decoder_index] + filename = 'generated_%d_%d.png' % (self._epoch, self._steps) + z = np.random.normal(size=(samples, self._latent_dim)) + images = decoder.predict(z) + self._thread_pool.submit(self.save_plot, images, filename) + + @staticmethod + def save_plot(images, filename): + images = (images + 1.) * 127.5 + images = np.clip(images, 0., 255.) + images = images.astype('uint8') + rows = [] + for i in range(0, len(images), 4): + rows.append(np.concatenate(images[i:(i + 4), :, :, :], axis=0)) + plot = np.concatenate(rows, axis=1).squeeze() + Image.fromarray(plot).save(filename) + + +class ModelsCheckpoint(Callback): + + def __init__(self, epoch_format, *models): + super().__init__() + self._epoch_format = epoch_format + self._models = models + + def on_epoch_end(self, epoch, logs=None): + suffix = self._epoch_format.format(epoch=epoch + 1, **logs) + for model in self._models: + model.save_weights(model.name + suffix) diff --git a/vaegan/data.py b/vaegan/data.py new file mode 100644 index 0000000..06fa161 --- /dev/null +++ b/vaegan/data.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 + +import os.path +import glob +from multiprocessing.pool import ThreadPool as Pool + +import numpy as np +from PIL import Image + + +NUM_SAMPLES = 202599 + +proj_root = os.path.split(os.path.dirname(__file__))[0] +images_path = os.path.join(proj_root, 'img_align_celeba_png', '*.png') + + +def _load_image(f): + im = Image.open(f) \ + .crop((0, 20, 178, 198)) \ + .resize((64, 64), Image.BICUBIC) + return np.asarray(im) + + +def celeba_loader(batch_size, normalize=True, num_child=4, seed=0, workers=8): + rng = np.random.RandomState(seed) + images = glob.glob(images_path) + + with Pool(workers) as p: + while True: + rng.shuffle(images) + for s in range(0, len(images), batch_size): + e = s + batch_size + batch_names = images[s:e] + batch_images = p.map(_load_image, batch_names) + batch_images = np.stack(batch_images) + + if normalize: + batch_images = batch_images / 127.5 - 1. + # To be sure + batch_images = np.clip(batch_images, -1., 1.) + + # Yield the same batch num_child times since the images will be consumed + # by num_child different child generators + for i in range(num_child): + yield batch_images + + +def mnist_loader(batch_size, normalize=True, num_child=4, seed=0, workers=8): + from keras.datasets import mnist + (x_train, _), (_, _) = mnist.load_data() + + x_train_new = np.zeros((x_train.shape[0], 64, 64), dtype='int32') + + for i, img in enumerate(x_train): + im = Image.fromarray(img).resize((64, 64), Image.BICUBIC) + x_train_new[i] = np.asarray(im) + + x_train = x_train_new.reshape(-1, 64, 64, 1) + del x_train_new + + if normalize: + x_train = x_train / 127.5 - 1. + # To be sure + x_train = np.clip(x_train, -1., 1.) + + rng = np.random.RandomState(seed) + while True: + rng.shuffle(x_train) + for s in range(0, len(x_train), batch_size): + e = s + batch_size + batch_images = x_train[s:e] + + # Yield the same batch num_child times since the images will be consumed + # by num_child different child generators + for i in range(num_child): + yield batch_images + + +def discriminator_loader(img_loader, latent_dim=128, seed=0): + rng = np.random.RandomState(seed) + while True: + x = next(img_loader) + batch_size = x.shape[0] + # Sample z from isotropic Gaussian + z_p = rng.normal(size=(batch_size, latent_dim)) + + y_real = np.ones((batch_size,), dtype='float32') + y_fake = np.zeros((batch_size,), dtype='float32') + + yield [x, z_p], [y_real, y_fake, y_fake] + + +def decoder_loader(img_loader, latent_dim=128, seed=0): + rng = np.random.RandomState(seed) + while True: + x = next(img_loader) + batch_size = x.shape[0] + # Sample z from isotropic Gaussian + z_p = rng.normal(size=(batch_size, latent_dim)) + # Label as real + y_real = np.ones((batch_size,), dtype='float32') + yield [x, z_p], [y_real, y_real] + + +def encoder_loader(img_loader): + while True: + x = next(img_loader) + yield x, None diff --git a/vaegan/losses.py b/vaegan/losses.py new file mode 100644 index 0000000..ba34cb4 --- /dev/null +++ b/vaegan/losses.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 + +import numpy as np + +from keras import backend as K + + +def mean_gaussian_negative_log_likelihood(y_true, y_pred): + nll = 0.5 * np.log(2 * np.pi) + 0.5 * K.square(y_pred - y_true) + axis = tuple(range(1, len(K.int_shape(y_true)))) + return K.mean(K.sum(nll, axis=axis), axis=-1) diff --git a/vaegan/models.py b/vaegan/models.py index 749f581..737d800 100644 --- a/vaegan/models.py +++ b/vaegan/models.py @@ -4,28 +4,30 @@ from keras import backend as K from keras.models import Sequential, Model -from keras.layers import Input, Conv2D, BatchNormalization, Activation, Dense, Conv2DTranspose, Flatten, Reshape, \ - Lambda +from keras.layers import Input, Conv2D, BatchNormalization, Dense, Conv2DTranspose, Flatten, Reshape, \ + Lambda, LeakyReLU, Activation from keras.regularizers import l2 +from .losses import mean_gaussian_negative_log_likelihood -def create_models(wdecay=1e-5, bn_mom=0.9, bn_eps=1e-6): - image_shape = (64, 64, 1) - n_channels = image_shape[-1] + +def create_models(n_channels=3, recon_depth=9, wdecay=1e-5, bn_mom=0.9, bn_eps=1e-6): + + image_shape = (64, 64, n_channels) n_encoder = 1024 n_discriminator = 512 latent_dim = 128 - epsilon_std = 1.0 decode_from_shape = (8, 8, 256) n_decoder = np.prod(decode_from_shape) - l2_regularizer = l2(wdecay) + leaky_relu_alpha = 0.2 - def conv_block(x, filters, transpose=False): + def conv_block(x, filters, leaky=True, transpose=False, name=''): conv = Conv2DTranspose if transpose else Conv2D + activation = LeakyReLU(leaky_relu_alpha) if leaky else Activation('relu') layers = [ - conv(filters, 5, strides=2, padding='same', kernel_regularizer=l2_regularizer), - BatchNormalization(momentum=bn_mom, epsilon=bn_eps), - Activation('relu') + conv(filters, 5, strides=2, padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name=name + 'conv'), + BatchNormalization(momentum=bn_mom, epsilon=bn_eps, name=name + 'bn'), + activation ] if x is None: return layers @@ -34,68 +36,128 @@ def conv_block(x, filters, transpose=False): return x # Encoder - enc_input = Input(shape=image_shape, name='input_image') - x = enc_input - for f in [64, 128, 256]: - x = conv_block(x, f) - x = Flatten()(x) - x = Dense(n_encoder, kernel_regularizer=l2_regularizer)(x) - x = BatchNormalization()(x) - x = Activation('relu')(x) + def create_encoder(): + x = Input(shape=image_shape, name='enc_input') - z_mean = Dense(latent_dim, name='z_mean')(x) - z_log_var = Dense(latent_dim, name='z_log_var')(x) + y = conv_block(x, 64, name='enc_blk_1_') + y = conv_block(y, 128, name='enc_blk_2_') + y = conv_block(y, 256, name='enc_blk_3_') + y = Flatten()(y) + y = Dense(n_encoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='enc_h_dense')(y) + y = BatchNormalization(name='enc_h_bn')(y) + y = LeakyReLU(leaky_relu_alpha)(y) - encoder = Model(enc_input, [z_mean, z_log_var], name='encoder') + z_mean = Dense(latent_dim, name='z_mean', kernel_initializer='he_uniform')(y) + z_log_var = Dense(latent_dim, name='z_log_var', kernel_initializer='he_uniform')(y) - def sampling(args): - z_mean, z_log_var = args - epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., - stddev=epsilon_std) - return z_mean + K.exp(z_log_var / 2) * epsilon - - sampling_layer = Lambda(sampling, output_shape=(latent_dim,)) + return Model(x, [z_mean, z_log_var], name='encoder') # Decoder decoder = Sequential([ - Dense(n_decoder, kernel_regularizer=l2_regularizer, input_shape=(latent_dim,)), - BatchNormalization(), - Activation('relu'), + Dense(n_decoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', input_shape=(latent_dim,), name='dec_h_dense'), + BatchNormalization(name='dec_h_bn'), + LeakyReLU(leaky_relu_alpha), Reshape(decode_from_shape), - *conv_block(None, 256, transpose=True), - *conv_block(None, 128, transpose=True), - *conv_block(None, 32, transpose=True), - Conv2D(n_channels, 5, activation='tanh', padding='same', kernel_regularizer=l2_regularizer, name='output_image') + *conv_block(None, 256, transpose=True, name='dec_blk_1_'), + *conv_block(None, 128, transpose=True, name='dec_blk_2_'), + *conv_block(None, 32, transpose=True, name='dec_blk_3_'), + Conv2D(n_channels, 5, activation='tanh', padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dec_output') ], name='decoder') # Discriminator - discriminator = Sequential([ - Conv2D(32, 5, activation='relu', padding='same', kernel_regularizer=l2_regularizer, input_shape=image_shape), - *conv_block(None, 128), - *conv_block(None, 256), - *conv_block(None, 256), - Flatten(), - Dense(n_discriminator, kernel_regularizer=l2_regularizer), - BatchNormalization(), - Activation('relu'), - Dense(1, activation='sigmoid', kernel_regularizer=l2_regularizer) - ], name='discriminator') - - vae = Model(encoder.inputs, decoder(sampling_layer(encoder.outputs)), name='vae') - - kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) - vae_loss = K.mean(kl_loss) - - vaegan = Model(vae.inputs, discriminator(vae.outputs), name='vaegan') - vaegan.add_loss(vae_loss) - - return encoder, decoder, discriminator, vae, vae_loss - - -# from keras.utils.vis_utils import plot_model -# e, d, dis, vae, gan = create_models() -# plot_model(e, show_shapes=True, to_file='encoder.png') -# plot_model(d, show_shapes=True, to_file='decoder.png') -# plot_model(dis, show_shapes=True, to_file='discriminator.png') -# plot_model(vae, show_shapes=True, to_file='vae.png') -# plot_model(gan, show_shapes=True, to_file=' gan.png') + def create_discriminator(): + x = Input(shape=image_shape, name='dis_input') + + layers = [ + Conv2D(32, 5, padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_blk_1_conv'), + LeakyReLU(leaky_relu_alpha), + *conv_block(None, 128, leaky=True, name='dis_blk_2_'), + *conv_block(None, 256, leaky=True, name='dis_blk_3_'), + *conv_block(None, 256, leaky=True, name='dis_blk_4_'), + Flatten(), + Dense(n_discriminator, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_dense'), + BatchNormalization(name='dis_bn'), + LeakyReLU(leaky_relu_alpha), + Dense(1, activation='sigmoid', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dis_output') + ] + + y = x + y_feat = None + for i, layer in enumerate(layers, 1): + y = layer(y) + # Output the features at the specified depth + if i == recon_depth: + y_feat = y + + return Model(x, [y, y_feat], name='discriminator') + + encoder = create_encoder() + discriminator = create_discriminator() + + return encoder, decoder, discriminator + + +def _sampling(args): + """Reparameterization trick by sampling fr an isotropic unit Gaussian. + Instead of sampling from Q(z|X), sample eps = N(0,I) + + # Arguments: + args (tensor): mean and log of variance of Q(z|X) + # Returns: + z (tensor): sampled latent vector + """ + z_mean, z_log_var = args + batch = K.shape(z_mean)[0] + dim = K.int_shape(z_mean)[1] + # by default, random_normal has mean=0 and std=1.0 + epsilon = K.random_normal(shape=(batch, dim)) + return z_mean + K.exp(0.5 * z_log_var) * epsilon + + +def build_graph(encoder, decoder, discriminator, recon_vs_gan_weight=1e-6): + image_shape = K.int_shape(encoder.input)[1:] + latent_shape = K.int_shape(decoder.input)[1:] + + sampler = Lambda(_sampling, output_shape=latent_shape, name='sampler') + + # Inputs + x = Input(shape=image_shape, name='input_image') + # z_p is sampled directly from isotropic gaussian + z_p = Input(shape=latent_shape, name='z_p') + + # Build computational graph + + z_mean, z_log_var = encoder(x) + z = sampler([z_mean, z_log_var]) + + x_tilde = decoder(z) + x_p = decoder(z_p) + + dis_x, dis_feat = discriminator(x) + dis_x_tilde, dis_feat_tilde = discriminator(x_tilde) + dis_x_p = discriminator(x_p)[0] + + # Compute losses + + # Learned similarity metric + dis_nll_loss = mean_gaussian_negative_log_likelihood(dis_feat, dis_feat_tilde) + + # KL divergence loss + kl_loss = K.mean(-0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)) + + # Create models for training + encoder_train = Model(x, dis_feat_tilde, name='e') + encoder_train.add_loss(kl_loss) + encoder_train.add_loss(dis_nll_loss) + + decoder_train = Model([x, z_p], [dis_x_tilde, dis_x_p], name='de') + normalized_weight = recon_vs_gan_weight / (1. - recon_vs_gan_weight) + decoder_train.add_loss(normalized_weight * dis_nll_loss) + + discriminator_train = Model([x, z_p], [dis_x, dis_x_tilde, dis_x_p], name='di') + + # Additional models for testing + vae = Model(x, x_tilde, name='vae') + vaegan = Model(x, dis_x_tilde, name='vaegan') + + return encoder_train, decoder_train, discriminator_train, vae, vaegan diff --git a/vaegan/training.py b/vaegan/training.py new file mode 100644 index 0000000..908ad3e --- /dev/null +++ b/vaegan/training.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 + +from keras import callbacks as cbks + + +def fit_models(callback_model, + models, + generators, + metrics_names, + batch_size, + steps_per_epoch=None, + epochs=1, + verbose=1, + callbacks=None, + initial_epoch=0): + epoch = initial_epoch + + # Prepare display labels. + callback_metrics = [n for m in metrics_names for n in m.keys()] + + # prepare callbacks + stateful_metric_names = [] + for model in models: + model.history = cbks.History() + try: + stateful_metric_names.extend(model.stateful_metric_names) + except AttributeError: + stateful_metric_names.extend(model.model.stateful_metric_names) + _callbacks = [cbks.BaseLogger( + stateful_metrics=stateful_metric_names)] + if verbose: + _callbacks.append( + cbks.ProgbarLogger( + count_mode='steps', + stateful_metrics=stateful_metric_names)) + _callbacks += (callbacks or []) + [model.history for model in models] + callbacks = cbks.CallbackList(_callbacks) + + # it's possible to callback a different model than self: + callbacks.set_model(callback_model) + callbacks.set_params({ + 'epochs': epochs, + 'steps': steps_per_epoch, + 'verbose': verbose, + 'do_validation': False, + 'metrics': callback_metrics, + }) + callbacks.on_train_begin() + + try: + callback_model.stop_training = False + # Construct epoch logs. + epoch_logs = {} + while epoch < epochs: + for model in models: + try: + stateful_metric_functions = model.stateful_metric_functions + except AttributeError: + stateful_metric_functions = model.model.stateful_metric_functions + for m in stateful_metric_functions: + m.reset_states() + callbacks.on_epoch_begin(epoch) + steps_done = 0 + batch_index = 0 + while steps_done < steps_per_epoch: + + # build batch logs + batch_logs = {} + batch_logs['batch'] = batch_index + batch_logs['size'] = batch_size + callbacks.on_batch_begin(batch_index, batch_logs) + + for model, output_generator, metrics in zip(models, generators, metrics_names): + + generator_output = next(output_generator) + + if not hasattr(generator_output, '__len__'): + raise ValueError('Output of generator should be ' + 'a tuple `(x, y, sample_weight)` ' + 'or `(x, y)`. Found: ' + + str(generator_output)) + + if len(generator_output) == 2: + x, y = generator_output + sample_weight = None + elif len(generator_output) == 3: + x, y, sample_weight = generator_output + else: + raise ValueError('Output of generator should be ' + 'a tuple `(x, y, sample_weight)` ' + 'or `(x, y)`. Found: ' + + str(generator_output)) + + outs = model.train_on_batch(x, y, sample_weight=sample_weight) + + if not isinstance(outs, list): + outs = [outs] + + for name, i in metrics.items(): + batch_logs[name] = outs[i] + + callbacks.on_batch_end(batch_index, batch_logs) + + batch_index += 1 + steps_done += 1 + + # Epoch finished. + if callback_model.stop_training: + break + + callbacks.on_epoch_end(epoch, epoch_logs) + epoch += 1 + if callback_model.stop_training: + break + + finally: + pass + + callbacks.on_train_end() + + return [model.history for model in models]