[WIP] Added implementation of SGHMC and example. #415

Merged 4 commits on Jan 28, 2017
5 changes: 4 additions & 1 deletion docs/tex/api/inference-classes.tex
@@ -148,7 +148,10 @@ \subsubsection{Exact Inference}
:members:

.. autoclass:: edward.inferences.SGLD
:members:
:members:

.. autoclass:: edward.inferences.SGHMC
:members:

}}

2 changes: 1 addition & 1 deletion edward/__init__.py
@@ -13,7 +13,7 @@
# Direct imports for convenience
from edward.criticisms import evaluate, ppc
from edward.inferences import Inference, MonteCarlo, VariationalInference, \
HMC, MetropolisHastings, SGLD, \
HMC, MetropolisHastings, SGLD, SGHMC, \
KLpq, KLqp, MFVI, ReparameterizationKLqp, ReparameterizationKLKLqp, \
ReparameterizationEntropyKLqp, ScoreKLqp, ScoreKLKLqp, ScoreEntropyKLqp, \
MAP, Laplace
1 change: 1 addition & 0 deletions edward/inferences/__init__.py
@@ -10,4 +10,5 @@
from edward.inferences.metropolis_hastings import *
from edward.inferences.monte_carlo import *
from edward.inferences.sgld import *
from edward.inferences.sghmc import *
from edward.inferences.variational_inference import *
149 changes: 149 additions & 0 deletions edward/inferences/sghmc.py
@@ -0,0 +1,149 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import tensorflow as tf

from edward.inferences.monte_carlo import MonteCarlo
from edward.models import Normal, RandomVariable, Empirical
from edward.util import copy


class SGHMC(MonteCarlo):
"""Stochastic gradient Hamiltonian Monte Carlo (Chen et al., 2014).

Notes
-----
In conditional inference, we infer :math:`z` in :math:`p(z, \\beta
\mid x)` while fixing inference over :math:`\\beta` using another
distribution :math:`q(\\beta)`.
``SGHMC`` substitutes the model's log marginal density

.. math::

\log p(x, z) = \log \mathbb{E}_{q(\\beta)} [ p(x, z, \\beta) ]
\\approx \log p(x, z, \\beta^*)

leveraging a single Monte Carlo sample, where :math:`\\beta^* \sim
q(\\beta)`. This is unbiased (and therefore asymptotically exact as a
pseudo-marginal method) if :math:`q(\\beta) = p(\\beta \mid x)`.
"""
def __init__(self, *args, **kwargs):
"""
Examples
--------
>>> z = Normal(mu=0.0, sigma=1.0)
>>> x = Normal(mu=tf.ones(10) * z, sigma=1.0)
>>>
>>> qz = Empirical(tf.Variable(tf.zeros([500])))
>>> data = {x: np.array([0.0] * 10, dtype=np.float32)}
>>> inference = ed.SGHMC({z: qz}, data)
"""
super(SGHMC, self).__init__(*args, **kwargs)

def initialize(self, step_size=0.25, friction=0.1, *args, **kwargs):
"""
Parameters
----------
step_size : float, optional
Constant scale factor of learning rate.
friction : float, optional
Constant scale on the friction term in the Hamiltonian system.
"""
self.step_size = step_size
self.friction = friction
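# One velocity variable per latent variable, shaped like a single sample
# (qz.params has shape [n_samples] + event_shape).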
self.v = {z: tf.Variable(tf.zeros(qz.params.get_shape()[1:]))
for z, qz in six.iteritems(self.latent_vars)}
return super(SGHMC, self).initialize(*args, **kwargs)

def build_update(self):
"""
Simulate Hamiltonian dynamics with friction using a discretized
integrator. Its discretization error goes to zero as the learning rate
decreases.
Implements the update equations from (15) of Chen et al., 2014.
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}
old_v_sample = {z: v for z, v in six.iteritems(self.v)}

# Simulate Hamiltonian dynamics with friction.
friction = tf.constant(self.friction, dtype=tf.float32)
learning_rate = tf.constant(self.step_size * 0.01, dtype=tf.float32)
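# Note: the effective learning rate is step_size scaled by a fixed factor of
# 0.01; it sets both the gradient step and (with friction) the variance of the
# injected noise below.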
grad_log_joint = tf.gradients(self._log_joint(old_sample),
list(six.itervalues(old_sample)))

# v_sample is so named because it represents a velocity rather than a momentum.
sample = {}
v_sample = {}
for z, qz, grad_log_p in \
zip(six.iterkeys(self.latent_vars),
six.itervalues(self.latent_vars),
grad_log_joint):
event_shape = qz.get_event_shape()
normal = Normal(mu=tf.zeros(event_shape),
sigma=(tf.sqrt(learning_rate * friction) *
tf.ones(event_shape)))
sample[z] = old_sample[z] + old_v_sample[z]
v_sample[z] = ((1. - 0.5 * friction) * old_v_sample[z] +
learning_rate * grad_log_p + normal.sample())

# Update Empirical random variables.
assign_ops = []
variables = {x.name: x for x in
tf.get_default_graph().get_collection(tf.GraphKeys.VARIABLES)}
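# Trace back through the ops wrapping qz.params to recover the underlying
# tf.Variable, so the new sample can be scatter-updated into row self.t.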
for z, qz in six.iteritems(self.latent_vars):
variable = variables[qz.params.op.inputs[0].op.inputs[0].name]
assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
assign_ops.append(tf.assign(self.v[z], v_sample[z]).op)

# Increment n_accept.
assign_ops.append(self.n_accept.assign_add(1))
return tf.group(*assign_ops)

def _log_joint(self, z_sample):
"""
Utility function to calculate the model's log joint density,
log p(x, z), for inputs z (and fixed data x).

Parameters
----------
z_sample : dict
Latent variable keys to samples.
"""
if self.model_wrapper is None:
scope = 'inference_' + str(id(self))
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

log_joint = 0.0
for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
z_log_prob = tf.reduce_sum(z_copy.log_prob(dict_swap[z]))
if z in self.scale:
z_log_prob *= self.scale[z]

log_joint += z_log_prob

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
x_log_prob = tf.reduce_sum(x_copy.log_prob(dict_swap[x]))
if x in self.scale:
x_log_prob *= self.scale[x]

log_joint += x_log_prob
else:
x = self.data
log_joint = self.model_wrapper.log_prob(x, z_sample)

return log_joint
97 changes: 97 additions & 0 deletions examples/bayesian_linear_regression_sghmc.py
@@ -0,0 +1,97 @@
#!/usr/bin/env python
"""Bayesian linear regression using variational inference.

This version visualizes prior and posterior predictive fits of the model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import tensorflow as tf

from edward.models import Normal, Empirical


def build_toy_dataset(N, noise_std=0.5):
X = np.concatenate([np.linspace(0, 2, num=N // 2),
                    np.linspace(6, 8, num=N // 2)])
y = 2.0 * X + 10 * np.random.normal(0, noise_std, size=N)
X = X.astype(np.float32).reshape((N, 1))
y = y.astype(np.float32)
return X, y


ed.set_seed(42)

N = 40 # number of data points
D = 1 # number of features

# DATA
X_train, y_train = build_toy_dataset(N)
X_test, y_test = build_toy_dataset(N)

# MODEL
X = tf.placeholder(tf.float32, [N, D])
w = Normal(mu=tf.zeros(D), sigma=tf.ones(D))
b = Normal(mu=tf.zeros(1), sigma=tf.ones(1))
y = Normal(mu=ed.dot(X, w) + b, sigma=tf.ones(N))

# INFERENCE
T = 5000 # Number of samples.
nburn = 100 # Number of burn-in samples.
stride = 10 # Frequency with which to plot samples.
qw = Empirical(params=tf.Variable(tf.random_normal([T, D])))
qb = Empirical(params=tf.Variable(tf.random_normal([T, 1])))

inference = ed.SGHMC({w: qw, b: qb}, data={X: X_train, y: y_train})
inference.run(step_size=1e-3)


# CRITICISM

# Plot posterior samples.
sns.jointplot(qb.params.eval()[nburn:T:stride],
qw.params.eval()[nburn:T:stride])
plt.show()

# Posterior predictive checks.
y_post = ed.copy(y, {w: qw.mean(), b: qb.mean()})
# This is equivalent to
# y_post = Normal(mu=ed.dot(X, qw.mean()) + qb.mean(), sigma=tf.ones(N))

print("Mean squared error on test data:")
print(ed.evaluate('mean_squared_error', data={X: X_test, y_post: y_test}))

print("Displaying prior predictive samples.")
n_prior_samples = 10

w_prior = w.sample(n_prior_samples).eval()
b_prior = b.sample(n_prior_samples).eval()

plt.scatter(X_train, y_train)

inputs = np.linspace(-1, 10, num=400, dtype=np.float32)
for ns in range(n_prior_samples):
output = inputs * w_prior[ns] + b_prior[ns]
plt.plot(inputs, output)

plt.show()

print("Displaying posterior predictive samples.")
n_posterior_samples = 10

w_post = qw.sample(n_posterior_samples).eval()
b_post = qb.sample(n_posterior_samples).eval()

plt.scatter(X_train, y_train)

inputs = np.linspace(-1, 10, num=400, dtype=np.float32)
for ns in range(n_posterior_samples):
output = inputs * w_post[ns] + b_post[ns]
plt.plot(inputs, output)

plt.show()
70 changes: 70 additions & 0 deletions examples/normal_sghmc.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python
"""Correlated normal posterior. Inference with stochastic gradient Hamiltonian
Monte Carlo.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import edward as ed
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from edward.models import Empirical, MultivariateNormalFull

plt.style.use("ggplot")

# Plotting helper function.


def mvn_plot_contours(z, label=False, ax=None):
"""
Plot the contours of a 2-d Normal or MultivariateNormalFull object.
Scale the axes to show 3 standard deviations.
"""
sess = ed.get_session()
mu = sess.run(z.mu)
mu_x, mu_y = mu
Sigma = sess.run(z.sigma)
sigma_x, sigma_y = np.sqrt(Sigma[0, 0]), np.sqrt(Sigma[1, 1])
xmin, xmax = mu_x - 3 * sigma_x, mu_x + 3 * sigma_x
ymin, ymax = mu_y - 3 * sigma_y, mu_y + 3 * sigma_y
xs = np.linspace(xmin, xmax, num=100)
ys = np.linspace(ymin, ymax, num=100)
X, Y = np.meshgrid(xs, ys)
T = tf.convert_to_tensor(np.c_[X.flatten(), Y.flatten()], dtype=tf.float32)
Z = sess.run(tf.exp(z.log_prob(T))).reshape((len(xs), len(ys)))
if ax is None:
fig, ax = plt.subplots()
cs = ax.contour(X, Y, Z)
if label:
plt.clabel(cs, inline=1, fontsize=10)


# Example body.
ed.set_seed(42)

# MODEL
z = MultivariateNormalFull(mu=tf.ones(2),
sigma=tf.constant([[1.0, 0.8], [0.8, 1.0]]))

# INFERENCE
qz = Empirical(params=tf.Variable(tf.random_normal([2000, 2])))

inference = ed.SGHMC({z: qz})
inference.run(step_size=2e-2)

# CRITICISM
sess = ed.get_session()
mean, std = sess.run([qz.mean(), qz.std()])
print("Inferred posterior mean:")
print(mean)
print("Inferred posterior std:")
print(std)

# VISUALIZATION
fig, ax = plt.subplots()
trace = sess.run(qz.params)
ax.scatter(trace[:, 0], trace[:, 1], marker=".")
mvn_plot_contours(z, ax=ax)
plt.show()