Skip to content

Commit

Permalink
add fix also for MH
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Jan 20, 2017
1 parent ad6e45e commit cb24048
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 3 additions & 1 deletion edward/inferences/metropolis_hastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import six
import tensorflow as tf

from collections import OrderedDict
from edward.inferences.monte_carlo import MonteCarlo
from edward.models import RandomVariable, Uniform
from edward.util import copy
Expand Down Expand Up @@ -67,6 +68,7 @@ def build_update(self):
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}
old_sample = OrderedDict(old_sample)

# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
Expand All @@ -85,7 +87,7 @@ def build_update(self):
scope_new = 'inference_' + str(id(self)) + '/new'

# Draw proposed sample and calculate acceptance ratio.
new_sample = {}
new_sample = old_sample.copy() # copy to ensure same order
ratio = 0.0
for z, proposal_z in six.iteritems(self.proposal_vars):
# Build proposal g(znew | zold).
Expand Down
11 changes: 10 additions & 1 deletion tests/test-inferences/test_bayesian_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def four_layer_nn(x, W_1, W_2, W_3, b_1, b_2):

class test_inference_bayesian_nn_class(tf.test.TestCase):

def test_hmc_sgld(self):
def test_monte_carlo(self):
ed.set_seed(42)

# DATA
Expand Down Expand Up @@ -47,6 +47,9 @@ def test_hmc_sgld(self):
qb_1 = Empirical(params=tf.Variable(tf.random_normal([T, 20])))
qb_2 = Empirical(params=tf.Variable(tf.random_normal([T, 15])))

# note ideally these would be separate test methods; there's an
# issue with the tensorflow graph when re-running the above
# unfortunately
inference = ed.HMC(
{W_1: qW_1, b_1: qb_1, W_2: qW_2, b_2: qb_2, W_3: qW_3},
data={y: y_train, x_ph: X_train})
Expand All @@ -57,5 +60,11 @@ def test_hmc_sgld(self):
data={y: y_train, x_ph: X_train})
inference.run()

inference = ed.MetropolisHastings(
{W_1: qW_1, b_1: qb_1, W_2: qW_2, b_2: qb_2, W_3: qW_3},
{W_1: W_1, b_1: b_1, W_2: W_2, b_2: b_2, W_3: W_3},
data={y: y_train, x_ph: X_train})
inference.run()

if __name__ == '__main__':
tf.test.main()

0 comments on commit cb24048

Please sign in to comment.