Feature/ali #597
@@ -0,0 +1,86 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.gan_inference import GANInference
from edward.util import get_session


class ALI(GANInference):
  """Adversarially Learned Inference (Dumoulin et al., 2016) or
  Bidirectional Generative Adversarial Networks (Donahue et al., 2016)
  for joint learning of generator and inference networks.

  Works for the class of implicit (and differentiable) probabilistic
  models. These models do not require a tractable density and assume
  only a program that generates samples.
  """
  def __init__(self, *args, **kwargs):
    """
    Notes
    -----
    ``ALI`` matches a mapping from data to latent variables and a
    mapping from latent variables to data through a joint
    discriminator. The encoder approximates the posterior p(z|x)
    when the network is stochastic.

[Review comment] Actually this is untrue. The terminology "Adversarially learned inference" is a misnomer because the distribution p(z|x) in GANs is degenerate. More accurately would be to say
We discuss this misconception in my paper on deep implicit models (see the noise vs. latent variables part of Section 2.1). I personally prefer the other paper, which calls it "adversarial feature learning"/BiGANs; to me this description is more accurate. So I'd be for

    In building the computation graph for inference, the
    discriminator's parameters can be accessed with the variable
    scope "Disc", and the encoder and decoder parameters with the
    variable scope "Gen".

    Examples
    --------
    >>> with tf.variable_scope("Gen"):
    >>>   xf = gen_data(z_ph)
    >>>   zf = gen_latent(x_ph)
    >>> inference = ed.ALI({xf: x_data, zf: z_samples}, discriminator)

[Review comment] To match the interface of other inference methods, what about

    ed.ALI(latent_vars, data, discriminative_network)  # signature
    ed.ALI({zf: qz}, {xf: x_data}, discriminator)      # example

Here the model's noise
    """
    super(ALI, self).__init__(*args, **kwargs)

  def initialize(self, *args, **kwargs):
    super(ALI, self).initialize(*args, **kwargs)

  def build_loss_and_gradients(self, var_list):
    # Assumes self.data = {x_fake: x_true, z_fake: z_true}: generated
    # tensors as keys, their observed/prior counterparts as values.
    x_true = list(six.itervalues(self.data))[0]
    x_fake = list(six.iterkeys(self.data))[0]
    z_true = list(six.itervalues(self.data))[1]
    z_fake = list(six.iterkeys(self.data))[1]
with tf.variable_scope("Disc"): | ||
# xfzt := x_fake, z_true | ||
d_xtzf = self.discriminator(x_true, z_fake) | ||
with tf.variable_scope("Disc", reuse=True): | ||
# xfzt := x_fake, z_true | ||
d_xfzt = self.discriminator(x_fake, z_true) | ||
|
||
    # Discriminator loss: push D toward 1 on the (x_fake, z_true) pair
    # and toward 0 on the (x_true, z_fake) pair.
    loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_xtzf), logits=d_xtzf)
    # Generator/encoder loss: the same terms with flipped labels.
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_xtzf), logits=d_xtzf)

    loss_d = tf.reduce_mean(loss_d)
    loss = tf.reduce_mean(loss)

    var_list_d = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")

    grads_d = tf.gradients(loss_d, var_list_d)
    grads = tf.gradients(loss, var_list)
    grads_and_vars_d = list(zip(grads_d, var_list_d))
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars, loss_d, grads_and_vars_d

  def update(self, feed_dict=None, variables=None):
    info_dict = super(ALI, self).update(feed_dict, variables)
    return info_dict
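
For reference, the minimax objective from the ALI/BiGAN papers that the cross-entropy losses above correspond to is, roughly (E is the encoder, G the decoder, D the joint discriminator, q(x) the data distribution, p(z) the prior):

    \min_{G, E}\;\max_{D}\;
      \mathbb{E}_{x \sim q(x)}\!\left[\log D\big(x, E(x)\big)\right]
      + \mathbb{E}_{z \sim p(z)}\!\left[\log\big(1 - D\big(G(z), z\big)\big)\right]

The code implements both terms with tf.nn.sigmoid_cross_entropy_with_logits, assigns label 1 to the (x_fake, z_true) decoder pair rather than to the encoder pair (the two conventions are symmetric), and trains the generator and encoder with the flipped-label heuristic rather than by minimizing the value function directly.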
@@ -0,0 +1,133 @@
#!/usr/bin/env python
"""Adversarially Learned Inference (Dumoulin et al., 2016) or
Bidirectional Generative Adversarial Networks (Donahue et al., 2016)
for joint learning of generator and inference networks on MNIST.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import tensorflow as tf

from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
from edward.models import Uniform


M = 100  # batch size during training
d = 50  # latent dimension
leak = 0.2  # leak parameter for leaky ReLU
hidden_units = 300
encoder_variance = 0.01  # set to 0 for a deterministic encoder


def leakyrelu(x, alpha=leak):
  return tf.maximum(x, alpha * x)


def gen_latent(x, hidden_units):
  """Encoder: map data x to a latent code z."""
  h = slim.fully_connected(x, hidden_units, activation_fn=leakyrelu)
  z = slim.fully_connected(h, d, activation_fn=None)
  # Add Gaussian noise drawn fresh at each run so the encoder is
  # stochastic (encoder_variance = 0 recovers a deterministic encoder).
  return z + tf.random_normal(tf.shape(z), stddev=np.sqrt(encoder_variance))


def gen_data(z, hidden_units):
  """Decoder: map a latent code z to data x in [0, 1]^784."""
  h = slim.fully_connected(z, hidden_units, activation_fn=leakyrelu)
  x = slim.fully_connected(h, 784, activation_fn=tf.sigmoid)
  return x


def discriminative_network(x, y):
  """Joint discriminator over a (data, latent) pair; outputs a real-valued logit."""
  inputs = tf.concat([x, y], 1)
  h1 = slim.fully_connected(inputs, hidden_units, activation_fn=leakyrelu)
  logit = slim.fully_connected(h1, 1, activation_fn=None)
  return logit


def plot(samples):
  fig = plt.figure(figsize=(4, 4))
  gs = gridspec.GridSpec(4, 4)
  gs.update(wspace=0.05, hspace=0.05)

  for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

  return fig


ed.set_seed(42)

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
x_ph = tf.placeholder(tf.float32, [M, 784])
z_ph = tf.placeholder(tf.float32, [M, d])

# MODEL
with tf.variable_scope("Gen"):
  xf = gen_data(z_ph, hidden_units)
  zf = gen_latent(x_ph, hidden_units)

# INFERENCE
optimizer = tf.train.AdamOptimizer()
optimizer_d = tf.train.AdamOptimizer()
inference = ed.ALI(
    data={xf: x_ph, zf: z_ph}, discriminator=discriminative_network)
inference.initialize(
    optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000)

sess = ed.get_session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
  if t % inference.n_print == 1:
    # Compare against == 1 rather than == 0 so that x_batch and z_batch
    # from the previous iteration already exist.
    # Plot generated images from the current prior samples.
    samples = sess.run(xf, feed_dict={z_ph: z_batch})
    samples = samples[idx, ]
    fig = plot(samples)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Generated', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    # Plot the corresponding real data batch.
    fig = plot(x_batch[idx, ])
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Base', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    # Plot reconstructions: encode the data batch, then decode it.
    zsam = sess.run(zf, feed_dict={x_ph: x_batch})
    reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
    reconstructions = reconstructions[idx, ]
    fig = plot(reconstructions)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Reconstruct', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    i += 1

  x_batch, _ = mnist.train.next_batch(M)
  z_batch = np.random.normal(0, 1, [M, d])

  info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
  inference.print_progress(info_dict)
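
As a small follow-up to the example script above (not part of the diff; a sketch that reuses the script's own tensors and the standard MNIST test split), one could also save a grid of held-out reconstructions once training finishes:

# Sketch (not in the diff): after the training loop, reconstruct a batch of
# held-out test images by encoding with zf and decoding with xf.
x_test, _ = mnist.test.next_batch(M)
z_test = sess.run(zf, feed_dict={x_ph: x_test})
x_recon = sess.run(xf, feed_dict={z_ph: z_test})

fig = plot(x_recon[np.random.randint(M, size=16)])
plt.savefig(os.path.join(IMG_DIR, 'ReconstructTest.png'), bbox_inches='tight')
plt.close(fig)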

[Review comment] Both were published at ICLR 2017, so you can update the citation with ... et al., 2017.