Feature/ali #597
Changes from 10 commits
3080f88
9bb7663
d6378c7
141ab75
30c5ef6
9f85cf5
8a47fe1
7f7f542
d889e97
7fe4020
828b484
@@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.gan_inference import GANInference
from edward.util import get_session


class BiGANInference(GANInference):
  """Adversarially Learned Inference (Dumoulin et al., 2017) or
  Bidirectional Generative Adversarial Networks (Donahue et al., 2017)
  for joint learning of generator and inference networks.

  Works for the class of implicit (and differentiable) probabilistic
  models. These models do not require a tractable density and assume
  only a program that generates samples.
  """
  def __init__(self, latent_vars, data, discriminator):
    """
    Notes
    -----
    ``BiGANInference`` matches a mapping from data to latent variables and a
    mapping from latent variables to data through a joint discriminator.

    In building the computation graph for inference, the discriminator's
    parameters can be accessed with the variable scope "Disc"; the encoder
    and decoder parameters can be accessed with the variable scope "Gen".

    Examples
    --------
    >>> with tf.variable_scope("Gen"):
    >>>   xf = gen_data(z_ph)
    >>>   zf = gen_latent(x_ph)
    >>> inference = ed.BiGANInference({latent_vars: qz}, {xf: x_data}, discriminator)
    """
    joint = latent_vars.copy()
    joint.update(data)

    super(BiGANInference, self).__init__(joint, discriminator)
  def build_loss_and_gradients(self, var_list):
    x_true = list(six.itervalues(self.data))[0]
    x_fake = list(six.iterkeys(self.data))[0]

    z_true = list(six.itervalues(self.data))[1]
    z_fake = list(six.iterkeys(self.data))[1]

Review comment: With a dictionary, the ordering of entries is not guaranteed, so indexing into the merged latent_vars/data dictionary is fragile. I would keep the two dictionaries separate in __init__, for example:

    if not callable(discriminator):
      raise TypeError("discriminator must be a callable function.")
    self.discriminator = discriminator
    # call grandparent's method; avoid parent (GANInference)
    super(GANInference, self).__init__(latent_vars, data)

In addition, for consistency with other inferences, the latent variables should be read from self.latent_vars:

    z_true = list(six.iterkeys(self.latent_vars))[0]
    z_fake = list(six.itervalues(self.latent_vars))[0]

Author reply: Yeah, that's my mistake; I didn't realize that the order was different. I made the changes, hopefully this is more consistent.
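To make the ordering concern concrete, here is a small standalone sketch (the string stand-ins are illustrative placeholders, not the tensors from this diff):

    import six

    latent_vars = {"zf": "z_ph"}
    data = {"xf": "x_ph"}

    joint = latent_vars.copy()
    joint.update(data)

    # On Python 2 (and CPython before 3.7) plain-dict iteration order is not
    # guaranteed, so positional indexing into the merged dictionary is not a
    # reliable way to tell the data entry from the latent entry.
    print(list(six.itervalues(joint)))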
    with tf.variable_scope("Disc"):
      # xtzf := x_true, z_fake
      d_xtzf = self.discriminator(x_true, z_fake)
    with tf.variable_scope("Disc", reuse=True):
      # xfzt := x_fake, z_true
      d_xfzt = self.discriminator(x_fake, z_true)
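    # Adversarial objective on joint (x, z) pairs: the discriminator is trained
    # to assign label 1 to the (generated x, prior z) pair and label 0 to the
    # (real x, encoded z) pair, while the generator/encoder loss below uses the
    # opposite labels so that both mappings try to fool the discriminator.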
    loss_d = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(d_xtzf), logits=d_xtzf)
    loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(d_xfzt), logits=d_xfzt) + \
        tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(d_xtzf), logits=d_xtzf)

    loss_d = tf.reduce_mean(loss_d)
    loss = tf.reduce_mean(loss)
    var_list_d = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Disc")
    var_list = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope="Gen")

    grads_d = tf.gradients(loss_d, var_list_d)
    grads = tf.gradients(loss, var_list)
    grads_and_vars_d = list(zip(grads_d, var_list_d))
    grads_and_vars = list(zip(grads, var_list))
    return loss, grads_and_vars, loss_d, grads_and_vars_d
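For reference, the two sigmoid cross-entropy pairs above correspond to the ALI/BiGAN value function from the papers cited in the docstring, up to which joint pair is treated as the positive class (a relabeling that leaves the game unchanged). Paraphrasing the papers, not quoting this diff:

$$\min_{G,\,E}\;\max_{D}\;\mathbb{E}_{x\sim p(x)}\big[\log D(x, E(x))\big] \;+\; \mathbb{E}_{z\sim p(z)}\big[\log\big(1 - D(G(z), z)\big)\big]$$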
@@ -0,0 +1,134 @@
#!/usr/bin/env python
"""Adversarially Learned Inference (Dumoulin et al., 2017) or
Bidirectional Generative Adversarial Networks (Donahue et al., 2017)
for joint learning of generator and inference networks for MNIST.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import edward as ed
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import os
import tensorflow as tf

from tensorflow.contrib import slim
from tensorflow.examples.tutorials.mnist import input_data
from edward.models import Uniform


M = 100  # batch size during training
d = 50  # latent dimension
leak = 0.2  # leak parameter for leakyReLU
hidden_units = 300
encoder_variance = 0.01  # set to 0 for deterministic encoder


def leakyrelu(x, alpha=leak):
  return tf.maximum(x, alpha * x)


def gen_latent(x, hidden_units):
  h = slim.fully_connected(x, hidden_units, activation_fn=leakyrelu)
  z = slim.fully_connected(h, d, activation_fn=None)
  return z + np.random.normal(0, encoder_variance, np.shape(z))
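# Note on gen_latent above: np.random.normal runs once, while the graph is
# being built, so the encoder noise is a fixed numpy constant added to z
# rather than fresh noise on every training step; tf.random_normal would
# resample on each sess.run.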


def gen_data(z, hidden_units):
  h = slim.fully_connected(z, hidden_units, activation_fn=leakyrelu)
  x = slim.fully_connected(h, 784, activation_fn=tf.sigmoid)
  return x


def discriminative_network(x, y):
  # Discriminator must output probability in logits
  inputs = tf.concat([x, y], 1)
  h1 = slim.fully_connected(inputs, hidden_units, activation_fn=leakyrelu)
  logit = slim.fully_connected(h1, 1, activation_fn=None)
  return logit


def plot(samples):
  fig = plt.figure(figsize=(4, 4))
  plt.title(str(samples))
  gs = gridspec.GridSpec(4, 4)
  gs.update(wspace=0.05, hspace=0.05)

  for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

  return fig


ed.set_seed(42)

DATA_DIR = "data/mnist"
IMG_DIR = "img"

if not os.path.exists(DATA_DIR):
  os.makedirs(DATA_DIR)
if not os.path.exists(IMG_DIR):
  os.makedirs(IMG_DIR)

# DATA. MNIST batches are fed at training time.
mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)
x_ph = tf.placeholder(tf.float32, [M, 784])
z_ph = tf.placeholder(tf.float32, [M, d])

# MODEL
with tf.variable_scope("Gen"):
  xf = gen_data(z_ph, hidden_units)
  zf = gen_latent(x_ph, hidden_units)
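# xf = G(z_ph) is the decoder/generator sample and zf = E(x_ph) is the
# encoder's latent code for real data; both live under the "Gen" scope so
# BiGANInference can collect their variables together.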

# INFERENCE:
optimizer = tf.train.AdamOptimizer()
optimizer_d = tf.train.AdamOptimizer()
inference = ed.BiGANInference(
    latent_vars={zf: z_ph}, data={xf: x_ph}, discriminator=discriminative_network)

inference.initialize(
    optimizer=optimizer, optimizer_d=optimizer_d, n_iter=100000, n_print=3000)

sess = ed.get_session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

idx = np.random.randint(M, size=16)
i = 0
for t in range(inference.n_iter):
  if t % inference.n_print == 1:
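    # Plotting uses x_batch and z_batch drawn at the bottom of the previous
    # iteration; t % n_print == 1 is first true at t = 1, after one batch has
    # already been sampled.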
    samples = sess.run(xf, feed_dict={z_ph: z_batch})
    samples = samples[idx, ]
    fig = plot(samples)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Generated', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    fig = plot(x_batch[idx, ])
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Base', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    zsam = sess.run(zf, feed_dict={x_ph: x_batch})
    reconstructions = sess.run(xf, feed_dict={z_ph: zsam})
    reconstructions = reconstructions[idx, ]
    fig = plot(reconstructions)
    plt.savefig(os.path.join(IMG_DIR, '{}{}.png').format(
        'Reconstruct', str(i).zfill(3)), bbox_inches='tight')
    plt.close(fig)

    i += 1

  x_batch, _ = mnist.train.next_batch(M)
  z_batch = np.random.normal(0, 1, [M, d])

  info_dict = inference.update(feed_dict={x_ph: x_batch, z_ph: z_batch})
  inference.print_progress(info_dict)
Review comment: Should this line be {z_ph: zf} and {xf: x_ph}?