Skip to content

Commit

Permalink
fix test_klpq.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Feb 28, 2017
1 parent 9b12d9b commit 5ba1214
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions tests/test-inferences/test_klpq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,23 @@
class test_klpq_class(tf.test.TestCase):

def test_normalnormal_run(self):
return
# with self.test_session() as sess:
# x_data = np.array([0.0] * 50, dtype=np.float32)
with self.test_session() as sess:
x_data = np.array([0.0] * 50, dtype=np.float32)

# mu = Normal(mu=0.0, sigma=1.0)
# x = Normal(mu=tf.ones(50) * mu, sigma=1.0)
mu = Normal(mu=0.0, sigma=1.0)
x = Normal(mu=tf.ones(50) * mu, sigma=1.0)

# qmu_mu = tf.Variable(tf.random_normal([]))
# qmu_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([])))
# qmu = Normal(mu=qmu_mu, sigma=qmu_sigma)
qmu_mu = tf.Variable(tf.random_normal([]))
qmu_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([])))
qmu = Normal(mu=qmu_mu, sigma=qmu_sigma)

# # analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
# inference = ed.KLpq({mu: qmu}, data={x: x_data})
# inference.run(n_iter=5000)
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140)
inference = ed.KLpq({mu: qmu}, data={x: x_data})
inference.run(n_samples=25, n_iter=100)

# self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2)
# self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51),
# rtol=1e-2, atol=1e-2)
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51),
rtol=1e-1, atol=1e-1)

if __name__ == '__main__':
ed.set_seed(42)
Expand Down

0 comments on commit 5ba1214

Please sign in to comment.