Skip to content

Commit

Permalink
fix test_sgld.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Feb 28, 2017
1 parent b9d3495 commit 75f878e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tests/test-inferences/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_normalnormal_run(self):
mu = Normal(mu=0.0, sigma=1.0)
x = Normal(mu=tf.ones(50) * mu, sigma=1.0)

qmu = Empirical(params=tf.Variable(tf.ones([2000])))
qmu = Empirical(params=tf.Variable(tf.ones(2000)))

# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
inference = ed.HMC({mu: qmu}, data={x: x_data})
Expand Down
2 changes: 1 addition & 1 deletion tests/test-inferences/test_metropolishastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_normalnormal_run(self):
mu = Normal(mu=0.0, sigma=1.0)
x = Normal(mu=tf.ones(50) * mu, sigma=1.0)

qmu = Empirical(params=tf.Variable(tf.ones([2000])))
qmu = Empirical(params=tf.Variable(tf.ones(2000)))
proposal_mu = Normal(mu=0.0, sigma=1.0)

# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
Expand Down
23 changes: 11 additions & 12 deletions tests/test-inferences/test_sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@
class test_sgld_class(tf.test.TestCase):

def test_normalnormal_run(self):
return
# with self.test_session() as sess:
# x_data = np.array([0.0] * 50, dtype=np.float32)
with self.test_session() as sess:
x_data = np.array([0.0] * 50, dtype=np.float32)

# mu = Normal(mu=0.0, sigma=1.0)
# x = Normal(mu=tf.ones(50) * mu, sigma=1.0)
mu = Normal(mu=0.0, sigma=1.0)
x = Normal(mu=tf.ones(50) * mu, sigma=1.0)

# qmu = Empirical(params=tf.Variable(tf.ones([5000])))
qmu = Empirical(params=tf.Variable(tf.ones(5000)))

# # analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
# inference = ed.SGLD({mu: qmu}, data={x: x_data})
# inference.run(step_size=0.2)
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
inference = ed.SGLD({mu: qmu}, data={x: x_data})
inference.run(step_size=0.2)

# self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
# self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51),
# rtol=1e-2, atol=1e-2)
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51),
rtol=2e-2, atol=2e-2)

if __name__ == '__main__':
ed.set_seed(42)
Expand Down

0 comments on commit 75f878e

Please sign in to comment.