-
Notifications
You must be signed in to change notification settings - Fork 760
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement a few inference tests (#487)
* basic test file structure * klpq and map tests * cleanup vi tests, add mh test * more MCMC tests * found magic combination for SGLD to converge * fixed map, disabled sgld and klpq tests * fix test_sgld.py * fix test_klpq.py
- Loading branch information
1 parent
f20c8f8
commit a024c5e
Showing
6 changed files
with
203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal, Empirical | ||
|
||
|
||
class test_hmc_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu = Empirical(params=tf.Variable(tf.ones(2000))) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.HMC({mu: qmu}, data={x: x_data}) | ||
inference.run() | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2) | ||
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51), | ||
rtol=1e-2, atol=1e-2) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal | ||
|
||
|
||
class test_klpq_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu_mu = tf.Variable(tf.random_normal([])) | ||
qmu_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([]))) | ||
qmu = Normal(mu=qmu_mu, sigma=qmu_sigma) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.KLpq({mu: qmu}, data={x: x_data}) | ||
inference.run(n_samples=25, n_iter=100) | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1) | ||
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51), | ||
rtol=1e-1, atol=1e-1) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal | ||
|
||
|
||
class test_klqp_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu_mu = tf.Variable(tf.random_normal([])) | ||
qmu_sigma = tf.nn.softplus(tf.Variable(tf.random_normal([]))) | ||
qmu = Normal(mu=qmu_mu, sigma=qmu_sigma) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.KLqp({mu: qmu}, data={x: x_data}) | ||
inference.run(n_iter=5000) | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2) | ||
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51), | ||
rtol=1e-2, atol=1e-2) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal, PointMass | ||
|
||
|
||
class test_map_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu = PointMass(params=tf.Variable(tf.ones([]))) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.MAP({mu: qmu}, data={x: x_data}) | ||
inference.run(n_iter=1000) | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal, Empirical | ||
|
||
|
||
class test_metropolishastings_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu = Empirical(params=tf.Variable(tf.ones(2000))) | ||
proposal_mu = Normal(mu=0.0, sigma=1.0) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.MetropolisHastings({mu: qmu}, | ||
{mu: proposal_mu}, | ||
data={x: x_data}) | ||
inference.run() | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2) | ||
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51), | ||
rtol=1e-2, atol=1e-2) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import edward as ed | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from edward.models import Normal, Empirical | ||
|
||
|
||
class test_sgld_class(tf.test.TestCase): | ||
|
||
def test_normalnormal_run(self): | ||
with self.test_session() as sess: | ||
x_data = np.array([0.0] * 50, dtype=np.float32) | ||
|
||
mu = Normal(mu=0.0, sigma=1.0) | ||
x = Normal(mu=tf.ones(50) * mu, sigma=1.0) | ||
|
||
qmu = Empirical(params=tf.Variable(tf.ones(5000))) | ||
|
||
# analytic solution: N(mu=0.0, sigma=\sqrt{1/51}=0.140) | ||
inference = ed.SGLD({mu: qmu}, data={x: x_data}) | ||
inference.run(step_size=0.2) | ||
|
||
self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-2, atol=1e-2) | ||
self.assertAllClose(qmu.std().eval(), np.sqrt(1 / 51), | ||
rtol=5e-2, atol=5e-2) | ||
|
||
if __name__ == '__main__': | ||
ed.set_seed(42) | ||
tf.test.main() |