Explicitly track constrained and unconstrained versions of transformed variables #808

Merged: 2 commits, Jan 7, 2018
Changes from all commits
66 changes: 56 additions & 10 deletions edward/inferences/hmc.py
@@ -73,30 +73,43 @@ def build_update(self):
The updates assume each Empirical random variable is directly
parameterized by `tf.Variable`s.
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}

# Gather the initial state, transformed to unconstrained space.
try:
self.latent_vars_unconstrained
except:
raise ValueError("This implementation of HMC requires that all "
"variables have unconstrained support. Please "
"initialize with auto_transform=True to ensure "
"this. (if your variables already have unconstrained "
"support then doing this is a no-op).")
old_sample = {z_unconstrained:
tf.gather(qz_unconstrained.params, tf.maximum(self.t - 1, 0))
for z_unconstrained, qz_unconstrained in
six.iteritems(self.latent_vars_unconstrained)}
old_sample = OrderedDict(old_sample)

# Sample momentum.
old_r_sample = OrderedDict()
for z, qz in six.iteritems(self.latent_vars):
for z, qz in six.iteritems(self.latent_vars_unconstrained):
event_shape = qz.event_shape
normal = Normal(loc=tf.zeros(event_shape, dtype=qz.dtype),
scale=tf.ones(event_shape, dtype=qz.dtype))
old_r_sample[z] = normal.sample()

# Simulate Hamiltonian dynamics.
new_sample, new_r_sample = leapfrog(old_sample, old_r_sample,
self.step_size, self._log_joint,
self.step_size,
self._log_joint_unconstrained,
self.n_steps)

# Calculate acceptance ratio.
ratio = tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
for r in six.itervalues(old_r_sample)])
ratio -= tf.reduce_sum([0.5 * tf.reduce_sum(tf.square(r))
for r in six.itervalues(new_r_sample)])
ratio += self._log_joint(new_sample)
ratio -= self._log_joint(old_sample)
ratio += self._log_joint_unconstrained(new_sample)
ratio -= self._log_joint_unconstrained(old_sample)

# Accept or reject sample.
u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
@@ -108,19 +121,51 @@ def build_update(self):
# `tf.cond` returns tf.Tensor if output is a list of size 1.
sample_values = [sample_values]

sample = {z: sample_value for z, sample_value in
sample = {z_unconstrained: sample_value for
z_unconstrained, sample_value in
zip(six.iterkeys(new_sample), sample_values)}

# Update Empirical random variables.
assign_ops = []
for z, qz in six.iteritems(self.latent_vars):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
for z_unconstrained, qz_unconstrained in six.iteritems(
self.latent_vars_unconstrained):
variable = qz_unconstrained.get_variables()[0]
assign_ops.append(tf.scatter_update(
variable, self.t, sample[z_unconstrained]))

# Increment n_accept (if accepted).
assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
return tf.group(*assign_ops)

def _log_joint_unconstrained(self, z_sample):
"""
Given a sample in unconstrained latent space, transform it back into
the original space, and compute the log joint density with appropriate
Jacobian correction.
"""

unconstrained_to_z = {v: k for (k, v) in self.transformations.items()}

# transform all samples back into the original (potentially
# constrained) space.
z_sample_transformed = {}
log_det_jacobian = 0.0
for z_unconstrained, qz_unconstrained in z_sample.items():
z = (unconstrained_to_z[z_unconstrained]
if z_unconstrained in unconstrained_to_z
else z_unconstrained)

try:
bij = self.transformations[z].bijector
z_sample_transformed[z] = bij.inverse(qz_unconstrained)
log_det_jacobian += tf.reduce_sum(
bij.inverse_log_det_jacobian(qz_unconstrained))
except: # if z not in self.transformations,
# or is not a TransformedDist w/ bijector
z_sample_transformed[z] = qz_unconstrained

return self._log_joint(z_sample_transformed) + log_det_jacobian

def _log_joint(self, z_sample):
"""Utility function to calculate model's log joint density,
log p(x, z), for inputs z (and fixed data x).
@@ -133,6 +178,7 @@ def _log_joint(self, z_sample):
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()

for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
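The core of `_log_joint_unconstrained` is a change of variables: a sample `zeta` in unconstrained space is mapped back to the constrained value with `bij.inverse(zeta)`, and the log density picks up the `bij.inverse_log_det_jacobian(zeta)` term. A minimal, self-contained sketch of that correction (the `Gamma` prior and Softplus bijector below are illustrative assumptions, not part of the diff; the single-argument `inverse_log_det_jacobian` call mirrors the `tf.contrib` API used above):

```python
import tensorflow as tf
from tensorflow.contrib.distributions import bijectors
from edward.models import Gamma

# A positive prior and a bijector whose forward map sends (0, inf) to R.
# Inverse-Softplus is one such choice; in the PR, ed.util.transform picks
# the bijector and stores the transformed variable in inference.transformations[z].
z = Gamma(2.0, 2.0)
bij = bijectors.Invert(bijectors.Softplus())    # forward: (0, inf) -> R

zeta = tf.constant(0.3)                         # a sample in unconstrained space
z_val = bij.inverse(zeta)                       # map back onto (0, inf)
log_det = bij.inverse_log_det_jacobian(zeta)    # log |d z_val / d zeta|

# Change of variables: the log density of the unconstrained variable at zeta
# equals the constrained log density at z_val plus the log-Jacobian term.
log_p_unconstrained = z.log_prob(z_val) + log_det

with tf.Session() as sess:
  print(sess.run([z_val, log_p_unconstrained]))
```

The method sums this Jacobian term over every transformed latent variable and adds it to `_log_joint` evaluated at the back-transformed samples, so the acceptance ratio above is computed with respect to the correct density on the unconstrained space.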
44 changes: 33 additions & 11 deletions edward/inferences/inference.py
@@ -13,6 +13,7 @@
from edward.util import check_data, check_latent_vars, get_session, \
get_variables, Progbar, transform

from tensorflow.contrib.distributions import bijectors

@six.add_metaclass(abc.ABCMeta)
class Inference(object):
@@ -217,25 +218,46 @@ def initialize(self, n_iter=1000, n_print=None, scale=None,

self.scale = scale

# Set of all latent variables binded to their transformation on
# the unconstrained space (if any).
# map from original latent vars to unconstrained versions
self.transformations = {}
if auto_transform:
latent_vars = self.latent_vars.copy()
self.latent_vars = {}
self.latent_vars = {} # maps original latent vars to constrained Q's
self.latent_vars_unconstrained = {} # maps unconstrained vars to unconstrained Q's
for z, qz in six.iteritems(latent_vars):
if hasattr(z, 'support') and hasattr(qz, 'support') and \
z.support != qz.support and qz.support != 'point':
z_transform = transform(z)
self.transformations[z] = z_transform
if qz.support == 'points': # don't transform empirical approx's
self.latent_vars[z_transform] = qz
z.support != qz.support and qz.support != 'point':

# transform z to an unconstrained space
z_unconstrained = transform(z)
self.transformations[z] = z_unconstrained

# make sure we also have a qz that covers the unconstrained space
if qz.support == "points":
qz_unconstrained = qz
else:
qz_transform = transform(qz)
self.latent_vars[z_transform] = qz_transform
self.transformations[qz] = qz_transform
qz_unconstrained = transform(qz)
self.latent_vars_unconstrained[z_unconstrained] = qz_unconstrained

# additionally construct the transformation of qz
# back into the original constrained space
if z_unconstrained != z:
qz_constrained = transform(
qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))

try: # attempt to pushforward the params of Empirical distributions
qz_constrained.params = z_unconstrained.bijector.inverse(
qz_unconstrained.params)
except: # qz_unconstrained is not an Empirical distribution
pass

else:
qz_constrained = qz_unconstrained

self.latent_vars[z] = qz_constrained
else:
self.latent_vars[z] = qz
self.latent_vars_unconstrained[z] = qz
del latent_vars

if logdir is not None:
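To make the new bookkeeping concrete, here is a hedged sketch of what `initialize` now holds for a single constrained latent variable (the `Gamma`/`Normal` pair is an illustrative assumption; `transform` and `bijectors.Invert` are used exactly as in the diff above):

```python
import tensorflow as tf
from tensorflow.contrib.distributions import bijectors
from edward.models import Gamma, Normal
from edward.util import transform

z = Gamma(2.0, 2.0)                           # prior supported on (0, inf)
qz = Normal(loc=tf.Variable(0.0),             # approximation living on all of R
            scale=tf.nn.softplus(tf.Variable(0.0)))

# transformations: original latent var -> its unconstrained version.
z_unconstrained = transform(z)

# latent_vars_unconstrained: unconstrained var -> unconstrained approximation
# (transform should be a no-op for qz, since Normal already has real support).
qz_unconstrained = transform(qz)

# latent_vars: original latent var -> approximation pushed back onto the
# constrained support via the inverse of z_unconstrained's bijector.
qz_constrained = transform(
    qz_unconstrained, bijectors.Invert(z_unconstrained.bijector))

# The three maps initialize() ends up holding (names as in the diff):
transformations = {z: z_unconstrained}
latent_vars_unconstrained = {z_unconstrained: qz_unconstrained}
latent_vars = {z: qz_constrained}
```

With `auto_transform=True`, samplers such as HMC then work entirely on `latent_vars_unconstrained`, while `latent_vars[z]` keeps exposing an approximation on z's original support, which is what the updated `test_hmc_default` below reads back.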
66 changes: 59 additions & 7 deletions tests/inferences/test_inference_auto_transform.py
@@ -7,7 +7,7 @@
import tensorflow as tf

from edward.models import (Empirical, Gamma, Normal, PointMass,
TransformedDistribution)
TransformedDistribution, Beta, Bernoulli)
from edward.util import transform
from tensorflow.contrib.distributions import bijectors

Expand Down Expand Up @@ -129,9 +129,9 @@ def test_hmc_custom(self):
# target distribution.
n_samples = 10000
x_unconstrained = inference.transformations[x]
qx_constrained = Empirical(x_unconstrained.bijector.inverse(qx.params))
qx_constrained_params = x_unconstrained.bijector.inverse(qx.params)
x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
qx_mean, qx_var = tf.nn.moments(qx_constrained_params[500:], 0)
stats = sess.run([x_mean, qx_mean, x_var, qx_var])
self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)
@@ -152,16 +152,68 @@ def test_hmc_default(self):

# Check approximation on constrained space has same moments as
# target distribution.
n_samples = 10000
x_unconstrained = inference.transformations[x]
qx = inference.latent_vars[x_unconstrained]
qx_constrained = Empirical(x_unconstrained.bijector.inverse(qx.params))
n_samples = 1000
qx_constrained = inference.latent_vars[x]
x_mean, x_var = tf.nn.moments(x.sample(n_samples), 0)
qx_mean, qx_var = tf.nn.moments(qx_constrained.params[500:], 0)
stats = sess.run([x_mean, qx_mean, x_var, qx_var])
self.assertAllClose(stats[0], stats[1], rtol=1e-1, atol=1e-1)
self.assertAllClose(stats[2], stats[3], rtol=1e-1, atol=1e-1)

def test_hmc_betabernoulli(self):
"""Do we correctly handle dependencies of transformed variables?"""

with self.test_session() as sess:
# model
z = Beta(1., 1., name="z")
xs = Bernoulli(probs=z, sample_shape=10)
x_obs = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=np.int32)

# inference
qz_samples = tf.Variable(tf.random_uniform(shape=(1000,)))
qz = ed.models.Empirical(params=qz_samples, name="z_posterior")
inference_hmc = ed.inferences.HMC({z: qz}, data={xs: x_obs})
inference_hmc.run(step_size=1.0, n_steps=5, auto_transform=True)

# check that inferred posterior mean/variance is close to
# that of the exact Beta posterior
z_unconstrained = inference_hmc.transformations[z]
qz_constrained = z_unconstrained.bijector.inverse(qz_samples)
qz_mean, qz_var = sess.run(tf.nn.moments(qz_constrained, 0))

true_posterior = Beta(1. + np.sum(x_obs), 1. + np.sum(1-x_obs))
pz_mean, pz_var = sess.run((true_posterior.mean(),
true_posterior.variance()))
self.assertAllClose(qz_mean, pz_mean, rtol=5e-2, atol=5e-2)
self.assertAllClose(qz_var, pz_var, rtol=1e-2, atol=1e-2)

def test_klqp_betabernoulli(self):
with self.test_session() as sess:
# model
z = Beta(1., 1., name="z")
xs = Bernoulli(probs=z, sample_shape=10)
x_obs = np.asarray([0, 0, 1, 1, 0, 0, 0, 0, 0, 1], dtype=np.int32)

# inference
qz_mean = tf.get_variable("qz_mean",
initializer=tf.random_normal(()))
qz_std = tf.nn.softplus(tf.get_variable(name="qz_prestd",
initializer=tf.random_normal(())))
qz_unconstrained = ed.models.Normal(loc=qz_mean, scale=qz_std, name="z_posterior")

inference_klqp = ed.inferences.KLqp({z: qz_unconstrained}, data={xs: x_obs})
inference_klqp.run(n_iter=500, auto_transform=True)

z_unconstrained = inference_klqp.transformations[z]
qz_constrained = z_unconstrained.bijector.inverse(qz_unconstrained.sample(1000))
qz_mean, qz_var = sess.run(tf.nn.moments(qz_constrained, 0))

true_posterior = Beta(np.sum(x_obs) + 1., np.sum(1-x_obs) + 1.)
pz_mean, pz_var = sess.run((true_posterior.mean(),
true_posterior.variance()))
self.assertAllClose(qz_mean, pz_mean, rtol=5e-2, atol=5e-2)
self.assertAllClose(qz_var, pz_var, rtol=1e-2, atol=1e-2)

if __name__ == '__main__':
ed.set_seed(124125)
tf.test.main()
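A note on the expected values in the two new Beta-Bernoulli tests: `x_obs` contains 3 ones in 10 draws, so by Beta-Bernoulli conjugacy the exact posterior is Beta(1 + 3, 1 + 7) = Beta(4, 8), with mean 4/12 ≈ 0.333 and variance (4*8)/(12^2 * 13) ≈ 0.017. Both tests back-transform through the stored bijector and compare `tf.nn.moments` of the resulting samples against these values within the stated tolerances.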