Feature/ali #597

Merged 11 commits on Apr 15, 2017
edward/__init__.py (2 changes: 1 addition, 1 deletion)
@@ -16,7 +16,7 @@
HMC, MetropolisHastings, SGLD, SGHMC, \
KLpq, KLqp, MFVI, ReparameterizationKLqp, ReparameterizationKLKLqp, \
ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
-GANInference, ALI, WGANInference, MAP, Laplace
+GANInference, BiGANInference, WGANInference, MAP, Laplace
from edward.models import PyMC3Model, PythonModel, StanModel, \
RandomVariable
from edward.util import copy, dot, get_ancestors, get_children, \
edward/inferences/__init__.py (2 changes: 1 addition, 1 deletion)
@@ -2,7 +2,7 @@
from __future__ import division
from __future__ import print_function

-from edward.inferences.ali import *
+from edward.inferences.bigan_inference import *
from edward.inferences.gan_inference import *
from edward.inferences.hmc import *
from edward.inferences.inference import *
edward/inferences/bigan_inference.py (new file, 84 additions)
@@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.gan_inference import GANInference
from edward.util import get_session


class BiGANInference(GANInference):
  """Adversarially Learned Inference (Dumoulin et al., 2017) or
  Bidirectional Generative Adversarial Networks (Donahue et al., 2017)
  for joint learning of generator and inference networks.

  Works for the class of implicit (and differentiable) probabilistic
  models. These models do not require a tractable density and assume
  only a program that generates samples.
  """
  def __init__(self, latent_vars, data, discriminator):
    """
    Notes
    -----
    ``BiGANInference`` matches a mapping from data to latent variables and a
    mapping from latent variables to data through a joint
    discriminator.

    In building the computation graph for inference, the discriminator's
    parameters can be accessed with the variable scope "Disc", and the
    encoder and decoder parameters with the variable scope "Gen".

    Examples
    --------
    >>> with tf.variable_scope("Gen"):
    >>>   xf = gen_data(z_ph)
    >>>   zf = gen_latent(x_ph)
    >>> inference = ed.BiGANInference({latent_vars: qz}, {xf: x_data}, discriminator)
@dustinvtran (Member) commented on Apr 14, 2017:

Should this line be {z_ph: zf} and {xf: x_ph}?
"""
joint = latent_vars.copy()
joint.update(data)

super(BiGANInference, self).__init__(joint, discriminator)

  def build_loss_and_gradients(self, var_list):
    x_true = list(six.itervalues(self.data))[0]
@dustinvtran (Member) commented:

With a dictionary, self.data doesn't guarantee the first key-value pair is x
and the second is z. I think you can keep the self.latent_vars / self.data
attributes stored separately, doing something like this for __init__:

    if not callable(discriminator):
      raise TypeError("discriminator must be a callable function.")

    self.discriminator = discriminator
    # call grandparent's method; avoid parent (GANInference)
    super(GANInference, self).__init__(latent_vars, data)

In addition, for consistency with other inferences, the zs should be ordered
as z_true: z_fake, i.e., the prior latent variable is bound to the output of
the inference network. With the above you can do

    z_true = list(six.iterkeys(self.latent_vars))[0]
    z_fake = list(six.itervalues(self.latent_vars))[0]

Author (Contributor) replied:

Yeah, that's my mistake; I didn't realize the order was different. I made the
changes; hopefully this is more consistent.
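A minimal sketch of how the reviewer's two snippets above might be assembled,
assuming the constructor keeps the (latent_vars, data, discriminator)
signature, each dictionary holds a single binding, and callers pass
latent_vars={z_prior: z_encoded} as suggested; this illustrates the
suggestion rather than the code committed in this pull request:

    import six

    from edward.inferences.gan_inference import GANInference


    class BiGANInference(GANInference):
      def __init__(self, latent_vars, data, discriminator):
        if not callable(discriminator):
          raise TypeError("discriminator must be a callable function.")

        self.discriminator = discriminator
        # Call the grandparent's constructor directly, skipping GANInference,
        # so latent_vars and data are stored as separate attributes.
        super(GANInference, self).__init__(latent_vars, data)

      def build_loss_and_gradients(self, var_list):
        # One binding per dict, so no assumption about dict ordering is needed.
        x_fake = list(six.iterkeys(self.data))[0]    # generated data
        x_true = list(six.itervalues(self.data))[0]  # observed data
        z_true = list(six.iterkeys(self.latent_vars))[0]    # prior sample
        z_fake = list(six.itervalues(self.latent_vars))[0]  # encoder output
        # ... construct loss, loss_d, and their gradients as in the diff below.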

    x_fake = list(six.iterkeys(self.data))[0]

    z_true = list(six.itervalues(self.data))[1]
    z_fake = list(six.iterkeys(self.data))[1]

    with tf.variable_scope("Disc"):
      # xtzf := x_true, z_fake
      d_xtzf = self.discriminator(x_true, z_fake)
    with tf.variable_scope("Disc", reuse=True):
      # xfzt := x_fake, z_true
      d_xfzt = self.discriminator(x_fake, z_true)

    loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_xtzf), logits=d_xtzf)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_xtzf), logits=d_xtzf)

    loss_d = tf.reduce_mean(loss_d)
    loss = tf.reduce_mean(loss)

    var_list_d = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")

    grads_d = tf.gradients(loss_d, var_list_d)
    grads = tf.gradients(loss, var_list)
    grads_and_vars_d = list(zip(grads_d, var_list_d))
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars, loss_d, grads_and_vars_d
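For reference, the two sigmoid cross-entropy terms above implement the
ALI/BiGAN value function (Dumoulin et al., 2017; Donahue et al., 2017): the
discriminator sees joint pairs (x, E(x)) from the encoder and (G(z), z) from
the decoder and is trained to tell them apart, while the generator and encoder
are trained to fool it. With the labels as written here, the discriminator
pushes D(G(z), z) toward 1 and D(x, E(x)) toward 0 (the papers use the
opposite but symmetric convention), so up to the usual non-saturating rewrite
of the generator's term this corresponds to the minimax game

$$\min_{G, E}\ \max_{D}\ \ \mathbb{E}_{z \sim p(z)}\big[\log D(G(z), z)\big] + \mathbb{E}_{x \sim q(x)}\big[\log\big(1 - D(x, E(x))\big)\big].$$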
examples/bigan_example.py (new file, 134 additions)
@@ -0,0 +1,134 @@
#!/usr/bin/env python
"""Adversarially Learned Inference (Dumoulin et al., 2017) or
Bidirectional Generative Adversarial Networks (Donahue et al., 2017)
for joint learning of generator and inference networks on MNIST.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import tensorflow as tf

from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
from edward.models import Uniform


M = 100 # batch size during training
d = 50 # latent dimension
leak = 0.2 # leak parameter for leakyReLU
hidden_units = 300
encoder_variance = 0.01 # Set to 0 for deterministic encoder


def leakyrelu(x, alpha=leak):
  return tf.maximum(x, alpha * x)


def gen_latent(x, hidden_units):
  h = slim.fully_connected(x, hidden_units, activation_fn=leakyrelu)
  z = slim.fully_connected(h, d, activation_fn=None)
  return z + np.random.normal(0, encoder_variance, np.shape(z))


def gen_data(z, hidden_units):
  h = slim.fully_connected(z, hidden_units, activation_fn=leakyrelu)
  x = slim.fully_connected(h, 784, activation_fn=tf.sigmoid)
  return x


def discriminative_network(x, y):
  # The discriminator must output a real-valued logit (unnormalized probability).
  inputs = tf.concat([x, y], 1)
  h1 = slim.fully_connected(inputs, hidden_units, activation_fn=leakyrelu)
  logit = slim.fully_connected(h1, 1, activation_fn=None)
  return logit


def plot(samples):
  fig = plt.figure(figsize=(4, 4))
  plt.title(str(samples))
  gs = gridspec.GridSpec(4, 4)
  gs.update(wspace=0.05, hspace=0.05)

  for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

  return fig


ed.set_seed(42)

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
x_ph = tf.placeholder(tf.float32, [M, 784])
z_ph = tf.placeholder(tf.float32, [M, d])

# MODEL
with tf.variable_scope("Gen"):
  xf = gen_data(z_ph, hidden_units)
  zf = gen_latent(x_ph, hidden_units)

# INFERENCE
optimizer = tf.train.AdamOptimizer()
optimizer_d = tf.train.AdamOptimizer()
inference = ed.BiGANInference(
    latent_vars={zf: z_ph}, data={xf: x_ph}, discriminator=discriminative_network)

inference.initialize(
    optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000)

sess = ed.get_session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
  if t % inference.n_print == 1:
    samples = sess.run(xf, feed_dict={z_ph: z_batch})
    samples = samples[idx, ]
    fig = plot(samples)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Generated', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    fig = plot(x_batch[idx, ])
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Base', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    zsam = sess.run(zf, feed_dict={x_ph: x_batch})
    reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
    reconstructions = reconstructions[idx, ]
    fig = plot(reconstructions)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Reconstruct', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    i += 1

  x_batch, _ = mnist.train.next_batch(M)
  z_batch = np.random.normal(0, 1, [M, d])

  info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
  inference.print_progress(info_dict)