Showing 30 changed files with 2,121 additions and 1,280 deletions.
@@ -1,2 +1,37 @@
# vae_gan
Playground for variational autoencoders and generative adversarial nets

## Autoencoding beyond pixels using a learned similarity measure

*[Anders Boesen Lindbo Larsen](https://github.com/andersbll)*, *[Søren Kaae Sønderby](https://github.com/skaae)*, *[Hugo Larochelle](http://www.dmi.usherb.ca/~larocheh)*, *[Ole Winther](http://cogsys.imm.dtu.dk/staff/winther)*

Implementation of the method described in our [arXiv paper](http://arxiv.org/abs/1512.09300).
### Abstract
We present an autoencoder that leverages learned representations to better measure similarities in data space.
By combining a variational autoencoder with a generative adversarial network we can use learned feature representations in the GAN discriminator as basis for the VAE reconstruction objective.
Thereby, we replace element-wise errors with feature-wise errors to better capture the data distribution while offering invariance towards e.g. translation.
We apply our method to images of faces and show that it outperforms VAEs with element-wise similarity measures in terms of visual fidelity.
Moreover, we show that the method learns an embedding in which high-level abstract visual features (e.g. wearing glasses) can be modified using simple arithmetic.
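
To make the similarity measure concrete, here is a minimal NumPy sketch of the feature-wise reconstruction error; `disc_features` is a hypothetical stand-in for extracting hidden activations at some layer of the GAN discriminator (the role `recon_depth` plays in this code base), not a function from this repository:

```python
import numpy as np

def feature_wise_loss(x, x_recon, disc_features):
    # Reconstruction error measured on discriminator activations instead
    # of raw pixels (a Gaussian observation model in feature space).
    h_real = disc_features(x)         # features of the originals
    h_recon = disc_features(x_recon)  # features of the reconstructions
    return 0.5 * np.mean((h_real - h_recon) ** 2)
```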
### Getting started
We have tried to automate everything from data fetching to generating pretty images.
This means that you can get started in two steps:

1. Install [CUDArray](https://github.com/andersbll/cudarray) and [DeepPy](https://github.com/andersbll/deeppy).
2. Run `python celeba_aegan.py`.

You can also try out the other scripts if you want to experiment with different models/datasets.
### Examples
Coming soon ...
### Implementation references
We wish to thank the authors of the following projects for inspiration.
Our method would never have gotten off the ground without the insights gained from inspecting their code.
- [The Eyescream Project](https://github.com/facebook/eyescream)
- [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://github.com/Newmu/dcgan_code)
- Joost van Amersfoort's VAE implementations ([Theano](https://github.com/y0ast/Variational-Autoencoder) and [Torch](https://github.com/y0ast/VAE-Torch))
- [Ian Goodfellow's GAN implementation](https://github.com/goodfeli/adversarial)
- [Parmesan](https://github.com/casperkaae/parmesan)
@@ -0,0 +1,103 @@
import os
import numpy as np
import scipy as sp
import deeppy as dp

import architectures
from model import aegan
from video import Video
import output
from dataset.util import img_inverse_transform
def build_model(experiment_name, img_size, n_hidden=128, recon_depth=9,
                recon_vs_gan_weight=5e-5, real_vs_gen_weight=0.5,
                sample_z=True, wgain=1.0, wdecay=1e-5, bn_noise_std=0.0):
    # Encode non-default hyperparameters in the experiment name so that
    # output directories are self-describing.
    experiment_name += '_reconganweight%.1e' % recon_vs_gan_weight
    if recon_depth > 0:
        experiment_name += '_recondepth%i' % recon_depth
    if not np.isclose(real_vs_gen_weight, 0.5):
        experiment_name += '_realgenweight%.2f' % real_vs_gen_weight
    if not sample_z:
        experiment_name += '_nosamplez'
    if not np.isclose(wgain, 1.0):
        experiment_name += '_wgain%.1e' % wgain
    if not np.isclose(wdecay, 1e-5):
        experiment_name += '_wdecay%.1e' % wdecay
    if not np.isclose(bn_noise_std, 0.0):
        experiment_name += '_bnnoise%.2f' % bn_noise_std

    # Set up the network for the requested image size
    if img_size == 32:
        encoder, decoder, discriminator = architectures.img32x32(
            wgain=wgain, wdecay=wdecay, bn_noise_std=bn_noise_std
        )
    elif img_size == 64:
        encoder, decoder, discriminator = architectures.img64x64(
            wgain=wgain, wdecay=wdecay, bn_noise_std=bn_noise_std
        )
    else:
        raise ValueError('no architecture for img_size %i' % img_size)
    latent_encoder = architectures.vae_latent_encoder(n_hidden)
    model = aegan.AEGAN(
        encoder=encoder,
        latent_encoder=latent_encoder,
        decoder=decoder,
        discriminator=discriminator,
        recon_depth=recon_depth,
        sample_z=sample_z,
        recon_vs_gan_weight=recon_vs_gan_weight,
        real_vs_gen_weight=real_vs_gen_weight,
    )
    return model, experiment_name
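# For example, with the defaults above an experiment named 'celeba' gets the
# following name (plain Python string formatting, verifiable in a REPL):
#
#   >>> 'celeba' + '_reconganweight%.1e' % 5e-5 + '_recondepth%i' % 9
#   'celeba_reconganweight5.0e-05_recondepth9'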
def train(model, output_dir, train_input, test_input, lr_start=0.02,
          lr_stop=0.00001, lr_gamma=0.75, n_epochs=150, gan_margin=0.35):
    n_hidden = model.latent_encoder.n_out

    # For plotting: fix a test batch and a set of latent samples so that
    # reconstructions/samples stay comparable across epochs.
    original_x = np.array(next(test_input.batches())['x'])
    samples_z = np.random.normal(size=(len(original_x), n_hidden))
    samples_z = samples_z.astype(dp.float_)
    recon_video = Video(os.path.join(output_dir, 'convergence_recon.mp4'))
    sample_video = Video(os.path.join(output_dir, 'convergence_samples.mp4'))
    original_x_ = img_inverse_transform(original_x)
    sp.misc.imsave(os.path.join(output_dir, 'examples.png'),
                   dp.misc.img_tile(original_x_))

    # Train network
    learn_rule = dp.RMSProp()
    annealer = dp.GammaAnnealer(lr_start, lr_stop, n_epochs, gamma=lr_gamma)
    trainer = aegan.GradientDescent(model, train_input, learn_rule,
                                    margin=gan_margin)
    try:
        for e in range(n_epochs):
            model.phase = 'train'
            model.setup(**train_input.shapes)
            learn_rule.learn_rate = annealer.value(e) / train_input.batch_size
            trainer.train_epoch()

            # After each epoch, dump reconstructions and samples to video.
            model.phase = 'test'
            original_z = model.encode(original_x)
            recon_x = model.decode(original_z)
            samples_x = model.decode(samples_z)
            recon_x = img_inverse_transform(recon_x)
            samples_x = img_inverse_transform(samples_x)
            recon_video.append(dp.misc.img_tile(recon_x))
            sample_video.append(dp.misc.img_tile(samples_x))
    except KeyboardInterrupt:
        pass

    # Final outputs: samples, reconstructions and latent-space walks.
    model.phase = 'test'
    n_examples = 100
    test_input.reset()
    original_x = np.array(next(test_input.batches())['x'])[:n_examples]
    samples_z = np.random.normal(size=(n_examples, n_hidden))
    samples_z = samples_z.astype(dp.float_)  # cast to match training dtype
    output.samples(model, samples_z, output_dir, img_inverse_transform)
    output.reconstructions(model, original_x, output_dir,
                           img_inverse_transform)
    original_z = model.encode(original_x)
    output.walk(model, original_z, output_dir, img_inverse_transform)
    return model
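# Sketch (not part of the original script): how the two functions above are
# meant to compose. The dataset inputs are placeholders here; runner scripts
# such as celeba_aegan.py construct deeppy inputs from an actual dataset.
#
#   model, name = build_model('celeba', img_size=64)
#   train(model, output_dir=name, train_input=..., test_input=...)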
@@ -0,0 +1,235 @@
import numpy as np
import deeppy as dp
import deeppy.expr as expr

import model.ae
def affine(n_out, gain, wdecay=0.0, bias=0.0):
    # Fully connected layer with automatic weight-scale initialization.
    return expr.nnet.Affine(
        n_out=n_out, bias=bias,
        weights=dp.Parameter(dp.AutoFiller(gain), weight_decay=wdecay),
    )


def conv(n_filters, filter_size, stride=1, gain=1.0, wdecay=0.0,
         bias=0.0, border_mode='same'):
    # Square convolution; stride > 1 downsamples.
    return expr.nnet.Convolution(
        n_filters=n_filters, strides=(stride, stride),
        weights=dp.Parameter(dp.AutoFiller(gain), weight_decay=wdecay),
        bias=bias, filter_shape=(filter_size, filter_size),
        border_mode=border_mode,
    )


def backconv(n_filters, filter_size, stride=2, gain=1.0, wdecay=0.0,
             bias=0.0):
    # Backward ("transposed") convolution; stride 2 upsamples by 2x.
    return expr.nnet.BackwardConvolution(
        n_filters=n_filters, strides=(stride, stride),
        weights=dp.Parameter(dp.AutoFiller(gain), weight_decay=wdecay),
        bias=bias, filter_shape=(filter_size, filter_size), border_mode='same',
    )


def pool(method='max', win_size=3, stride=2, border_mode='same'):
    return expr.nnet.Pool(win_shape=(win_size, win_size), method=method,
                          strides=(stride, stride), border_mode=border_mode)
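# Illustrative sketch (not used by the experiments): the helpers above
# compose into strided conv stacks the same way the builders below use them.
#
#   stack = dp.expr.Sequential([
#       conv(32, 5, stride=2),  # 5x5 conv, halves each spatial dimension
#       expr.nnet.ReLU(),
#       pool(),                 # 3x3 max pool, stride 2
#   ])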
def vae_latent_encoder(n_hidden):
    # VAE-style latent code: a Gaussian code of size n_hidden
    # (cf. model.ae.NormalEncoder).
    latent_encoder = model.ae.NormalEncoder(n_hidden, dp.AutoFiller())
    return latent_encoder
def aae_latent_encoder(n_hidden, n_discriminator=1024, recon_weight=0.025):
    # Adversarial (AAE-style) latent code: a small MLP discriminator is
    # attached to the n_hidden-dimensional code.
    wgain = 1.0
    discriminator = dp.expr.Sequential([
        affine(n_discriminator, wgain, bias=None),
        expr.nnet.BatchNormalization(),
        expr.nnet.ReLU(),
        affine(n_discriminator, wgain, bias=None),
        expr.nnet.BatchNormalization(),
        expr.nnet.ReLU(),
        affine(1, wgain),
        expr.nnet.Sigmoid(),
    ])
    latent_encoder = model.ae.AdversarialEncoder(
        n_hidden, discriminator, dp.AutoFiller(), recon_weight=recon_weight,
    )
    return latent_encoder
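# Note (sketch): build_model in the training script uses vae_latent_encoder;
# swapping in this adversarial variant would be a one-line change there,
# e.g.  latent_encoder = architectures.aae_latent_encoder(n_hidden)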
def mnist(wgain=1.0, wdecay=0, bn_noise_std=0.0, n_units=1024):
    img_shape = (28, 28)
    n_in = np.prod(img_shape)
    n_encoder = n_units
    n_decoder = n_units
    n_discriminator = n_units

    def block(n_out):
        # Affine -> batch norm -> ReLU block of width n_out.
        return [
            affine(n_out, wgain, wdecay=wdecay),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
        ]

    encoder = dp.expr.Sequential(
        block(n_encoder) +
        block(n_encoder)
    )
    decoder = dp.expr.Sequential(
        block(n_decoder) +
        block(n_decoder) +
        [
            affine(n_in, wgain),
            expr.nnet.Sigmoid(),
        ]
    )
    discriminator = dp.expr.Sequential(
        block(n_discriminator) +
        block(n_discriminator) +
        [
            affine(1, wgain),
            expr.nnet.Sigmoid(),
        ]
    )
    return encoder, decoder, discriminator
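# Shape note (sketch): these MNIST builders are fully connected, so images
# are handled as flattened 28*28 = 784 vectors; the decoder's final affine
# maps back to n_in = 784 with a sigmoid for pixel intensities in [0, 1].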
def img32x32(wgain=1.0, wdecay=1e-5, bn_mom=0.9, bn_eps=1e-6,
             bn_noise_std=0.0):
    n_channels = 3
    n_encoder = 1024
    n_discriminator = 512
    decode_from_shape = (256, 4, 4)
    n_decoder = np.prod(decode_from_shape)

    def conv_block(n_filters, backward=False):
        # Strided (back-)conv -> spatial batch norm -> ReLU.
        block = []
        if backward:
            block.append(backconv(n_filters, 5, stride=2, gain=wgain,
                                  wdecay=wdecay, bias=None))
        else:
            block.append(conv(n_filters, 5, stride=2, gain=wgain,
                              wdecay=wdecay, bias=None))
        block.append(expr.nnet.SpatialBatchNormalization(
            momentum=bn_mom, eps=bn_eps, noise_std=bn_noise_std
        ))
        block.append(expr.nnet.ReLU())
        return block

    encoder = dp.expr.Sequential(
        conv_block(64) +
        conv_block(128) +
        conv_block(256) +
        [
            expr.Reshape((-1, 256*4*4)),
            affine(n_encoder, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
        ]
    )

    decoder = dp.expr.Sequential(
        [
            affine(n_decoder, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
            expr.Reshape((-1,) + decode_from_shape),
        ] +
        conv_block(192, backward=True) +
        conv_block(128, backward=True) +
        conv_block(32, backward=True) +
        [
            conv(n_channels, 5, wdecay=wdecay, gain=wgain),
            expr.Tanh(),
        ]
    )

    discriminator = dp.expr.Sequential(
        [
            conv(32, 5, wdecay=wdecay, gain=wgain),
            expr.nnet.ReLU(),
        ] +
        conv_block(128) +
        conv_block(192) +
        conv_block(256) +
        [
            expr.Reshape((-1, 256*4*4)),
            affine(n_discriminator, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
            affine(1, gain=wgain, wdecay=wdecay),
            expr.nnet.Sigmoid(),
        ]
    )
    return encoder, decoder, discriminator
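# Shape sketch: three stride-2 convs take 3x32x32 inputs down to 256x4x4
# (32 -> 16 -> 8 -> 4), matching Reshape((-1, 256*4*4)) above, and the
# decoder inverts this path from decode_from_shape = (256, 4, 4); the final
# Tanh bounds decoder outputs to [-1, 1].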
def img64x64(wgain=1.0, wdecay=1e-5, bn_mom=0.9, bn_eps=1e-6,
             bn_noise_std=0.0):
    n_channels = 3
    n_encoder = 1024
    n_discriminator = 512
    decode_from_shape = (256, 8, 8)
    n_decoder = np.prod(decode_from_shape)

    def conv_block(n_filters, backward=False):
        # Strided (back-)conv -> spatial batch norm -> ReLU.
        block = []
        if backward:
            block.append(backconv(n_filters, 5, stride=2, gain=wgain,
                                  wdecay=wdecay, bias=None))
        else:
            block.append(conv(n_filters, 5, stride=2, gain=wgain,
                              wdecay=wdecay, bias=None))
        block.append(expr.nnet.SpatialBatchNormalization(
            momentum=bn_mom, eps=bn_eps, noise_std=bn_noise_std
        ))
        block.append(expr.nnet.ReLU())
        return block

    encoder = dp.expr.Sequential(
        conv_block(64) +
        conv_block(128) +
        conv_block(256) +
        [
            expr.Reshape((-1, 256*8*8)),
            affine(n_encoder, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
        ]
    )

    decoder = dp.expr.Sequential(
        [
            affine(n_decoder, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
            expr.Reshape((-1,) + decode_from_shape),
        ] +
        conv_block(256, backward=True) +
        conv_block(128, backward=True) +
        conv_block(32, backward=True) +
        [
            conv(n_channels, 5, wdecay=wdecay, gain=wgain),
            expr.Tanh(),
        ]
    )

    discriminator = dp.expr.Sequential(
        [
            conv(32, 5, wdecay=wdecay, gain=wgain),
            expr.nnet.ReLU(),
        ] +
        conv_block(128) +
        conv_block(256) +
        conv_block(256) +
        [
            expr.Reshape((-1, 256*8*8)),
            affine(n_discriminator, gain=wgain, wdecay=wdecay, bias=None),
            expr.nnet.BatchNormalization(noise_std=bn_noise_std),
            expr.nnet.ReLU(),
            affine(1, gain=wgain, wdecay=wdecay),
            expr.nnet.Sigmoid(),
        ]
    )
    return encoder, decoder, discriminator
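# Shape sketch: as in img32x32 but one octave larger; stride-2 convs take
# 3x64x64 inputs to 256x8x8 (64 -> 32 -> 16 -> 8), and the decoder starts
# from decode_from_shape = (256, 8, 8).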