Feature/ali #597

Merged
merged 11 commits on Apr 15, 2017
86 changes: 86 additions & 0 deletions edward/inferences/ali.py
@@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.gan_inference import GANInference
from edward.util import get_session


class ALI(GANInference):
"""Adversarially Learned Inference (Dumoulin et al., 2016) or
Review comment (Member):
Both were published at ICLR 2017, so you can update the citation with ... et al., 2017.

  Bidirectional Generative Adversarial Networks (Donahue et al., 2016)
  for joint learning of generator and inference networks.

  Works for the class of implicit (and differentiable) probabilistic
  models. These models do not require a tractable density and assume
  only a program that generates samples.
  """
  def __init__(self, *args, **kwargs):
    """
    Notes
    -----
    ``ALI`` matches a mapping from data to latent variables and a
    mapping from latent variables to data through a joint
    discriminator. The encoder approximates the posterior p(z | x)
    when the network is stochastic.
Review comment (Member):
Actually this is untrue. The terminology "adversarially learned inference" is a misnomer because the distribution p(z | x) in GANs is degenerate. More accurate would be to say:

    Using a joint discriminator, ``ALI`` learns a mapping from noise to data
    as well as an inverse mapping from data to noise. The inference network
    is the inverse mapping.

We discuss this misconception in my paper on deep implicit models (see the noise vs. latent variables part of Section 2.1).

I personally prefer the other paper, which calls it "adversarial feature learning"/BiGANs; to me this description is more accurate. So I'd be for BiGANInference, but if you feel strongly about ALI, that's okay too.


    In building the computation graph for inference, the discriminator's
    parameters can be accessed with the variable scope "Disc", and the
    encoder and decoder parameters can be accessed with the variable
    scope "Gen".

    Examples
    --------
    >>> with tf.variable_scope("Gen"):
    >>>   xf = gen_data(z_ph)
    >>>   zf = gen_latent(x_ph)
    >>> inference = ed.ALI({xf: x_data, zf: z_samples}, discriminator)
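    >>> # For illustration (not part of the original example), the scoped
    >>> # parameters can be fetched after ``inference.initialize()``:
    >>> gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")
    >>> disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")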
Review comment (Member):
To match the interface of other inference methods, what about

    ed.ALI(latent_vars, data, discriminative_network)  # signature
    ed.ALI({zf: qz}, {xf: x_data}, discriminator)  # example

Here the model's noise zf is bound to the inference network output qz, and the model's generated output xf is bound to the data x_data.
"""
super(ALI, self).__init__(*args, **kwargs)

def initialize(self, *args, **kwargs):
super(ALI, self).initialize(*args, **kwargs)

def build_loss_and_gradients(self, var_list):
x_true = list(six.itervalues(self.data))[0]
x_fake = list(six.iterkeys(self.data))[0]
z_true = list(six.itervalues(self.data))[1]
z_fake = list(six.iterkeys(self.data))[1]
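    # This indexing assumes the data dict yields the (x_fake: x_true) pair
    # first and the (z_fake: z_true) pair second, i.e. the model's generated
    # data bound to observed data, then the inferred latents bound to prior
    # samples; it depends on the dict's iteration order.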
    with tf.variable_scope("Disc"):
      # xtzf := x_true, z_fake
      d_xtzf = self.discriminator(x_true, z_fake)
    with tf.variable_scope("Disc", reuse=True):
      # xfzt := x_fake, z_true
      d_xfzt = self.discriminator(x_fake, z_true)

    loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_xtzf), logits=d_xtzf)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_xtzf), logits=d_xtzf)
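    # The losses above: loss_d trains the discriminator to label decoder
    # pairs (x_fake, z_true) as 1 and encoder pairs (x_true, z_fake) as 0;
    # loss uses the flipped labels, so minimizing it trains the "Gen"-scoped
    # generator and encoder to fool the discriminator.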

    loss_d = tf.reduce_mean(loss_d)
    loss = tf.reduce_mean(loss)

    var_list_d = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")

    grads_d = tf.gradients(loss_d, var_list_d)
    grads = tf.gradients(loss, var_list)
    grads_and_vars_d = list(zip(grads_d, var_list_d))
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars, loss_d, grads_and_vars_d

  def update(self, feed_dict=None, variables=None):
    info_dict = super(ALI, self).update(feed_dict, variables)
    return info_dict
133 changes: 133 additions & 0 deletions examples/ali.py
@@ -0,0 +1,133 @@
#!/usr/bin/env python
"""Adversarially Learned Inference (Dumoulin et al., 2016) or
Bidirectional Generative Adversarial Networks (Donahue et al., 2016)
for joint learning of generator and inference networks, applied to MNIST.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import tensorflow as tf

from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
from edward.models import Uniform


M = 100 # batch size during training
d = 50 # latent dimension
leak = 0.2 # leak parameter for leakyReLU
hidden_units = 300
encoder_variance = 0.01  # std. dev. of noise added to encoder output; set to 0 for a deterministic encoder


def leakyrelu(x, alpha=leak):
  return tf.maximum(x, alpha * x)


def gen_latent(x, hidden_units):
h = slim.fully_connected(x, hidden_units, activation_fn=leakyrelu)
z = slim.fully_connected(h, d, activation_fn=None)
return z + np.random.normal(0, encoder_variance, np.shape(z))
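# Note: np.random.normal above is evaluated once, while the graph is being
# built, so the encoder "noise" is a fixed constant rather than being
# resampled each training step; tf.random_normal would resample per call.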


def gen_data(z, hidden_units):
  h = slim.fully_connected(z, hidden_units, activation_fn=leakyrelu)
  x = slim.fully_connected(h, 784, activation_fn=tf.sigmoid)
  return x


def discriminative_network(x, y):
  # The discriminator outputs an unnormalized log-probability (logit).
  inputs = tf.concat([x, y], 1)
  h1 = slim.fully_connected(inputs, hidden_units, activation_fn=leakyrelu)
  logit = slim.fully_connected(h1, 1, activation_fn=None)
  return logit


def plot(samples):
  fig = plt.figure(figsize=(4, 4))
  plt.title(str(samples))
  gs = gridspec.GridSpec(4, 4)
  gs.update(wspace=0.05, hspace=0.05)

  for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

  return fig


ed.set_seed(42)

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
x_ph = tf.placeholder(tf.float32, [M, 784])
z_ph = tf.placeholder(tf.float32, [M, d])

# MODEL
with tf.variable_scope("Gen"):
  xf = gen_data(z_ph, hidden_units)
  zf = gen_latent(x_ph, hidden_units)
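# xf are decoder samples generated from prior noise z_ph, and zf are encoder
# outputs for real images x_ph; ed.ALI pairs (xf, z_ph) against (x_ph, zf)
# inside its joint discriminator.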

# INFERENCE:
optimizer = tf.train.AdamOptimizer()
optimizer_d = tf.train.AdamOptimizer()
inference = ed.ALI(
    data={xf: x_ph, zf: z_ph}, discriminator=discriminative_network)
inference.initialize(
    optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000)
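# As in GANInference, a single inference.update call below should take one
# gradient step on the generator/encoder loss ("Gen" variables) and one on
# the discriminator loss ("Disc" variables) when no `variables` argument is
# passed.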

sess = ed.get_session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
  if t % inference.n_print == 1:
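    # Plotting at t % n_print == 1 (rather than 0) ensures x_batch and
    # z_batch from the previous iteration already exist when used here.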

    samples = sess.run(xf, feed_dict={z_ph: z_batch})
    samples = samples[idx, ]
    fig = plot(samples)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Generated', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    fig = plot(x_batch[idx, ])
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Base', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    zsam = sess.run(zf, feed_dict={x_ph: x_batch})
    reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
    reconstructions = reconstructions[idx, ]
    fig = plot(reconstructions)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Reconstruct', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    i += 1

  x_batch, _ = mnist.train.next_batch(M)
  z_batch = np.random.normal(0, 1, [M, d])

  info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
  inference.print_progress(info_dict)