Skip to content

Commit

Permalink
initial code
Browse files Browse the repository at this point in the history
  • Loading branch information
andersbll committed Nov 19, 2015
1 parent 8cdb5c3 commit c1824aa
Show file tree
Hide file tree
Showing 9 changed files with 1,218 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
datasets
plots
savestates/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
174 changes: 174 additions & 0 deletions cond_vaegan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from copy import deepcopy
import numpy as np
import cudarray as ca
import deeppy as dp
import deeppy.expr as expr

from vaegan import KLDivergence, NegativeGradient, SquareError


class AppendSpatially(expr.base.Binary):
    """Broadcast a (batch, f) feature array over the spatial grid of a
    (batch, c, h, w) image batch and append it as f extra channels,
    yielding a (batch, c+f, h, w) output expression.
    """

    def __call__(self, imgs, feats):
        # imgs: 4-d image expression; feats: 2-d conditioning expression.
        self.imgs = imgs
        self.feats = feats
        self.inputs = [imgs, feats]
        return self

    def setup(self):
        b, c, h, w = self.imgs.out_shape
        b_, f = self.feats.out_shape
        if b != b_:
            raise ValueError('batch size mismatch')
        self.out_shape = (b, c+f, h, w)
        self.out = ca.empty(self.out_shape)
        self.out_grad = ca.empty(self.out_shape)
        # Scratch buffer holding the spatially tiled feature maps.
        self.tmp = ca.zeros((b, f, h, w))

    def fprop(self):
        self.tmp.fill(0.0)
        # Reshape feats to (b, f, 1, 1) so the add broadcasts across h, w.
        feats = ca.reshape(self.feats.out, self.feats.out.shape + (1, 1))
        ca.add(feats, self.tmp, out=self.tmp)
        # Stack images and tiled features along the channel axis.
        ca.extra.concatenate(self.imgs.out, self.tmp, axis=1, out=self.out)

    def bprop(self):
        # Route the image slice of the gradient back to imgs; the feature
        # slice lands in the scratch buffer.
        # NOTE(review): the gradient w.r.t. feats is never accumulated into
        # self.feats.out_grad -- looks intentional (conditioning labels are
        # not learned through this path), but confirm.
        ca.extra.split(self.out_grad, a_size=self.imgs.out_shape[1], axis=1,
                       out_a=self.imgs.out_grad, out_b=self.tmp)


class ConditionalSequential(expr.Sequential):
    """Sequential expression chain that threads a conditioning input *y*
    into the ops able to consume it (Concatenate / AppendSpatially); all
    other ops receive only the running activation.
    """

    def __call__(self, x, y):
        out = x
        for node in self.collection:
            takes_condition = isinstance(
                node, (expr.Concatenate, AppendSpatially))
            out = node(out, y) if takes_condition else node(out)
        return out


class ConditionalVAEGAN(dp.base.Model):
    """Conditional VAE/GAN hybrid model.

    `mode` selects the training objective:
      * 'vae'    -- KL divergence + reconstruction error only.
      * 'gan'    -- discriminator loss on real vs. generated samples only.
      * 'vaegan' -- both, with a weight-shared negative-gradient copy of
                    the generator for the adversarial pass.
    All expressions are conditioned on an extra label input `y`.
    """

    def __init__(self, encoder, sampler, generator, discriminator, mode,
                 reconstruct_error=None):
        self.encoder = encoder
        self.sampler = sampler
        self.generator = generator
        self.mode = mode
        self.discriminator = discriminator
        # Clipping epsilon keeping discriminator outputs away from 0/1 so
        # log() below stays finite.
        self.eps = 1e-4
        if reconstruct_error is None:
            reconstruct_error = SquareError()
        self.reconstruct_error = reconstruct_error
        if self.mode == 'vaegan':
            # Second generator sharing parameters with the first; used for
            # the adversarial branch so its gradient can be negated
            # independently of the reconstruction branch.
            self.generator_neg = deepcopy(generator)
            self.generator_neg.params = [p.share() for p in generator.params]

    def _embed_expr(self, x, y):
        """Expression mapping (x, y) to the latent mean (deterministic)."""
        h_enc = self.encoder(x, y)
        z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
        # Use the mean rather than a stochastic sample for embedding.
        z = z_mu
        return z

    def _reconstruct_expr(self, z, y):
        """Expression mapping a latent (z, y) back to image space."""
        return self.generator(z, y)

    def setup(self, x_shape, y_shape):
        """Build the training expression graph for inputs of the given
        shapes. Must be called before update()."""
        batch_size = x_shape[0]
        self.sampler.batch_size = x_shape[0]
        self.x_src = expr.Source(x_shape)
        self.y_src = expr.Source(y_shape)

        if self.mode in ['vae', 'vaegan']:
            # VAE branch: encode, sample, decode, and accumulate
            # KL + reconstruction losses.
            h_enc = self.encoder(self.x_src, self.y_src)
            z, z_mu, z_log_sigma, z_eps = self.sampler(h_enc)
            self.kld = KLDivergence()(z_mu, z_log_sigma)
            x_tilde = self.generator(z, self.y_src)
#            if self.mode == 'vaegan':
#                x_tilde = ScaleGradient()(x_tilde)
            self.logpxz = self.reconstruct_error(x_tilde, self.x_src)
            loss = self.kld + expr.sum(self.logpxz)

        if self.mode in ['gan', 'vaegan']:
            # GAN branch: generated samples flow through NegativeGradient
            # so the generator ascends the discriminator's loss.
            y = self.y_src
            if self.mode == 'gan':
                z = self.sampler.samples()
                x_tilde = self.generator(z, y)
                x_tilde = NegativeGradient()(x_tilde)
                gen_size = batch_size
            elif self.mode == 'vaegan':
                # Feed both encoded latents and prior samples (z_eps) to the
                # shared generator copy; gen batch is therefore 2x.
                z = NegativeGradient()(z)
                z = expr.Concatenate(axis=0)(z, z_eps)
                y = expr.Concatenate(axis=0)(y, self.y_src)
                x_tilde = self.generator_neg(z, y)
                x_tilde = NegativeGradient()(x_tilde)
                gen_size = batch_size*2
            x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
            y = expr.Concatenate(axis=0)(y, self.y_src)
            d = self.discriminator(x, y)
            d = expr.clip(d, self.eps, 1.0-self.eps)

            # sign/offset turn d into log(d) for real rows and log(1-d) for
            # generated rows in a single expression.
            real_size = batch_size
            sign = np.ones((real_size + gen_size, 1), dtype=ca.float_)
            sign[real_size:] = -1.0
            offset = np.zeros_like(sign)
            offset[real_size:] = 1.0

            self.gan_loss = expr.log(d*sign + offset)
            if self.mode == 'gan':
                loss = expr.sum(-self.gan_loss)
            elif self.mode == 'vaegan':
                loss = loss + expr.sum(-self.gan_loss)

        self._graph = expr.ExprGraph(loss)
        self._graph.out_grad = ca.array(1.0)
        self._graph.setup()

    @property
    def params(self):
        """Return (encoder, generator, discriminator) parameter lists;
        lists not used by the current mode are empty."""
        enc_params = []
        gen_params = self.generator.params
        dis_params = []
        if self.mode != 'vae':
            dis_params = self.discriminator.params
        if self.mode != 'gan':
            enc_params = self.encoder.params + self.sampler.params
        return enc_params, gen_params, dis_params

    def update(self, x, y):
        """Run one fprop/bprop pass on a batch and return
        (d_x_loss, d_z_loss, kld) diagnostics (zeros where inapplicable)."""
        self.x_src.out = x
        self.y_src.out = y
        self._graph.fprop()
        self._graph.bprop()
        kld = 0
        d_x_loss = 0
        d_z_loss = 0
        if self.mode != 'gan':
            kld = np.array(self.kld.out)
        if self.mode != 'vae':
            # First batch_size rows of gan_loss are real samples, the rest
            # are generated ones.
            gan_loss = -np.array(self.gan_loss.out)
            batch_size = x.shape[0]
            d_x_loss = float(np.mean(gan_loss[:batch_size]))
            d_z_loss = float(np.mean(gan_loss[batch_size:]))
        return d_x_loss, d_z_loss, kld

    def _batchwise(self, x, y, expr_fun):
        """Evaluate expr_fun over (x, y) in batches and concatenate the
        results, trimming padding added by the batching."""
        x = dp.input.Input.from_any(x)
        y = dp.input.Input.from_any(y)
        x_src = expr.Source(x.x_shape)
        y_src = expr.Source(y.x_shape)
        graph = expr.ExprGraph(expr_fun(x_src, y_src))
        graph.setup()
        out = []
        for x_batch, y_batch in zip(x.batches(), y.batches()):
            x_src.out = x_batch['x']
            y_src.out = y_batch['x']
            graph.fprop()
            out.append(np.array(graph.out))
        # Drop any tail rows from a padded final batch.
        out = np.concatenate(out)[:x.n_samples]
        return out

    def embed(self, x, y):
        """ Input to hidden. """
        return self._batchwise(x, y, self._embed_expr)

    def reconstruct(self, z, y):
        """ Hidden to input. """
        return self._batchwise(z, y, self._reconstruct_expr)
214 changes: 214 additions & 0 deletions cond_vaegan_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
#!/usr/bin/env python

import pickle
import numpy as np
import scipy as sp
import deeppy as dp
import deeppy.expr as expr

import cond_vaegan
import vaegan
from util import img_tile, one_hot, random_walk
from video import Video


def affine(n_out, gain):
    """Fully-connected layer with `n_out` units and auto-scaled weights."""
    weight_filler = dp.AutoFiller(gain)
    return expr.nnet.Affine(n_out=n_out, weights=weight_filler)


def conv(n_filters, filter_size, gain=1.0):
    """Square stride-1 convolution with 'same' border handling."""
    kernel_shape = (filter_size, filter_size)
    return expr.nnet.Convolution(
        n_filters=n_filters,
        filter_shape=kernel_shape,
        strides=(1, 1),
        border_mode='same',
        weights=dp.AutoFiller(gain),
    )


def pool(method='max'):
    """3x3 pooling (default max) with stride 2 and 'same' borders."""
    return expr.nnet.Pool(
        win_shape=(3, 3),
        strides=(2, 2),
        method=method,
        border_mode='same',
    )


def upscale():
    """Double the spatial resolution via perforated upsampling."""
    return expr.nnet.Rescale(method='perforated', factor=2)


def model_expressions(img_shape):
    """Build the (encoder, sampler, generator, discriminator) expressions
    for CIFAR-10 images.

    img_shape is (channels, height, width); only the channel count is
    used, to size the generator's output convolution.
    """
    n_channels = img_shape[0]
    gain = 1.0
    sigma = 0.001
    n_encoder = 1024
    n_discriminator = 1024
    # Latent dimensionality.
    n_hidden = 512
    # Spatial shape the generator starts from before two 2x upscales
    # (8x8 -> 16x16 -> 32x32).
    hidden_shape = (128, 8, 8)
    n_generator = np.prod(hidden_shape)

    # Encoder: two conv+pool stages down to 8x8, then a fully-connected
    # layer; labels are concatenated after flattening.
    encoder = cond_vaegan.ConditionalSequential([
        conv(32, 5, gain=gain),
        pool(),
        expr.nnet.ReLU(),
        conv(64, 5, gain=gain),
        pool(),
        expr.nnet.ReLU(),
        conv(96, 3, gain=gain),
        expr.nnet.ReLU(),
        expr.Reshape((-1, 96*8*8)),
        expr.Concatenate(axis=1),
        affine(n_encoder, gain),
        expr.nnet.ReLU(),
    ])
    sampler = vaegan.NormalSampler(
        n_hidden,
        weight_filler=dp.AutoFiller(gain),
        bias_filler=dp.NormalFiller(sigma),
    )
    # Generator: labels are concatenated at the dense input and appended
    # spatially before each conv stage.
    generator = cond_vaegan.ConditionalSequential([
        expr.Concatenate(axis=1),
        affine(n_generator, gain),
        expr.nnet.BatchNormalization(),
        expr.Reshape((-1,) + hidden_shape),
        upscale(),
        expr.nnet.ReLU(),
        cond_vaegan.AppendSpatially(),
        conv(256, 5, gain=gain),
        expr.nnet.SpatialBatchNormalization(),
        upscale(),
        expr.nnet.ReLU(),
        cond_vaegan.AppendSpatially(),
        conv(128, 5, gain=gain),
        expr.nnet.SpatialBatchNormalization(),
        expr.nnet.ReLU(),
        cond_vaegan.AppendSpatially(),
        conv(128, 5, gain=gain),
        expr.nnet.SpatialBatchNormalization(),
        expr.nnet.ReLU(),
        conv(n_channels, 3, gain=gain),
    ])
    # Discriminator: conv stack with dropout, sigmoid probability output;
    # labels concatenated after flattening.
    discriminator = cond_vaegan.ConditionalSequential([
        conv(32, 5, gain=gain),
        pool(),
        expr.nnet.ReLU(),
        expr.nnet.SpatialDropout(0.2),
        conv(64, 5, gain=gain),
        pool(),
        expr.nnet.ReLU(),
        expr.nnet.SpatialDropout(0.2),
        conv(96, 3, gain=gain),
        expr.nnet.ReLU(),
        expr.nnet.SpatialDropout(0.2),
        expr.Reshape((-1, 96*8*8)),
        expr.Concatenate(axis=1),
        affine(n_discriminator, gain),
        expr.nnet.ReLU(),
        expr.nnet.Dropout(0.5),
        affine(1, gain),
        expr.nnet.Sigmoid(),
    ])
    return encoder, sampler, generator, discriminator


def clip_range(imgs):
    """Smoothly squash image values into (-1, 1) with a scaled tanh."""
    return np.tanh(0.5 * imgs)


def run():
    """Train a conditional VAE/GAN on CIFAR-10 and write progress videos.

    The training mode ('vae', 'gan' or 'vaegan') is hard-coded below;
    model state is pickled to savestates/ on completion.
    """
    mode = 'gan'
    experiment_name = mode
    filename = 'savestates/cifar_cond_' + experiment_name + '.pickle'
    # Warm-start toggle: the second assignment deliberately disables
    # resuming; remove it to load the previous savestate.
    in_filename = filename
    in_filename = None
    print('experiment_name', experiment_name)
    print('in_filename', in_filename)
    print('filename', filename)

    # Fetch dataset
    dataset = dp.dataset.CIFAR10()
    x_train, y_train, x_test, y_test = dataset.arrays(dp_dtypes=True)
    n_classes = dataset.n_classes

    # Normalize pixel intensities
    scaler = dp.StandardScaler()
    x_train = scaler.fit_transform(x_train)
    x_test = scaler.transform(x_test)
    y_train = one_hot(y_train, n_classes).astype(dp.float_)
    y_test = one_hot(y_test, n_classes).astype(dp.float_)

    # Setup network
    if in_filename is None:
        print('Creating new model')
        img_shape = x_train.shape[1:]
        expressions = model_expressions(img_shape)
    else:
        print('Starting from %s' % in_filename)
        with open(in_filename, 'rb') as f:
            expressions = pickle.load(f)

    encoder, sampler, generator, discriminator = expressions
    model = cond_vaegan.ConditionalVAEGAN(
        encoder=encoder,
        sampler=sampler,
        generator=generator,
        discriminator=discriminator,
        mode=mode,
    )

    # Prepare network inputs
    batch_size = 64
    train_input = dp.SupervisedInput(x_train, y_train, batch_size=batch_size,
                                     epoch_size=150)

    # Plotting: fixed test examples for reconstructions and fixed latent
    # samples (10 per class) for generation videos.
    n_examples = 100
    examples = x_test[:n_examples]
    examples_y = y_test[:n_examples]
    samples_z = np.random.normal(size=(n_examples, model.sampler.n_hidden))
    samples_z = samples_z.astype(dp.float_)
    samples_y = ((np.arange(n_examples) // 10) % n_classes)
    samples_y = one_hot(samples_y, n_classes).astype(dp.float_)

    recon_video = Video('plots/cifar_' + experiment_name +
                        '_reconstruction.mp4')
    sample_video = Video('plots/cifar_' + experiment_name + '_samples.mp4')
    # NOTE(review): sp.misc.imsave was removed in SciPy 1.2 -- migrate to
    # imageio.imwrite when upgrading SciPy.
    sp.misc.imsave('cifar_examples.png', img_tile(dp.misc.to_b01c(examples)))

    def plot():
        # Epoch callback: append reconstruction and sample frames, then
        # rebuild the training graph (embed/reconstruct replace it).
        examples_z = model.embed(examples, examples_y)
        examples_recon = model.reconstruct(examples_z, examples_y)
        examples_recon = clip_range(examples_recon)
        recon_video.append(img_tile(dp.misc.to_b01c(examples_recon)))
        samples = clip_range(model.reconstruct(samples_z, samples_y))
        sample_video.append(img_tile(dp.misc.to_b01c(samples)))
        model.setup(**train_input.shapes)

    # Train network with a stepwise-decaying learning-rate schedule.
    runs = [
        (150, dp.RMSProp(learn_rate=0.1)),
        (150, dp.RMSProp(learn_rate=0.08)),
        (150, dp.RMSProp(learn_rate=0.06)),
        (150, dp.RMSProp(learn_rate=0.04)),
        (25, dp.RMSProp(learn_rate=0.01)),
    ]
    try:
        for n_epochs, learn_rule in runs:
            if mode == 'vae':
                vaegan.train(model, train_input, learn_rule, n_epochs,
                             epoch_callback=plot)
            else:
                vaegan.margin_train(model, train_input, learn_rule, n_epochs,
                                    epoch_callback=plot)
    except KeyboardInterrupt:
        pass

    # NOTE(review): raw_input is Python 2 only -- this line raises
    # NameError on Python 3 (use input() there).
    raw_input('\n\nsave model to %s?\n' % filename)
    with open(filename, 'wb') as f:
        expressions = encoder, sampler, generator, discriminator
        pickle.dump(expressions, f)

    print('Generating latent space video')
    walk_video = Video('plots/cifar_' + experiment_name + '_walk.mp4')
    for z in random_walk(samples_z, 500, step_std=0.15):
        samples = clip_range(model.reconstruct(z, samples_y))
        walk_video.append(img_tile(dp.misc.to_b01c(samples)))


# Script entry point.
if __name__ == '__main__':
    run()
Loading

0 comments on commit c1824aa

Please sign in to comment.