From d3d293665e02cc215926b20b7ffb599c9fc25013 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Sat, 27 Jul 2024 01:52:43 +0300
Subject: [PATCH] Added test of Predictive with the SplitReparam reparametrizer.

---
 pyro/infer/predictive.py          |  9 ++--
 tests/infer/reparam/test_split.py | 77 ++++++++++++++++++++++---------
 2 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/pyro/infer/predictive.py b/pyro/infer/predictive.py
index 3cf06b2827..b57a193f4d 100644
--- a/pyro/infer/predictive.py
+++ b/pyro/infer/predictive.py
@@ -10,11 +10,12 @@
 import pyro
 import pyro.poutine as poutine
+from pyro.infer.autoguide.initialization import InitMessenger, init_to_sample
 from pyro.infer.importance import LogWeightsMixin
 from pyro.infer.util import CloneMixin, plate_log_prob_sum
 from pyro.poutine.trace_struct import Trace
 from pyro.poutine.util import prune_subsample_sites
-from pyro.infer.autoguide.initialization import InitMessenger, init_to_median
+
 
 
 def _guess_max_plate_nesting(model, args, kwargs):
     """
@@ -86,8 +87,10 @@ def _predictive(
     mask=True,
 ):
     model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
-    initailized_model = InitMessenger(init_to_median)(model)
-    max_plate_nesting = _guess_max_plate_nesting(initailized_model, model_args, model_kwargs)
+    initialized_model = InitMessenger(init_to_sample)(model)
+    max_plate_nesting = _guess_max_plate_nesting(
+        initialized_model, model_args, model_kwargs
+    )
     vectorize = pyro.plate(
         _predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
     )
diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py
index 6337069ea0..fb450c4220 100644
--- a/tests/infer/reparam/test_split.py
+++ b/tests/infer/reparam/test_split.py
@@ -13,8 +13,7 @@
 from .util import check_init_reparam
 
-
-@pytest.mark.parametrize(
+event_shape_splits_dim = pytest.mark.parametrize(
     "event_shape,splits,dim",
     [
         ((6,), [2, 1, 3], -1),
         (
@@ -31,7 +30,13 @@
     ],
     ids=str,
 )
-@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
+
+
+batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
+
+
+@event_shape_splits_dim
+@batch_shape
 def test_normal(batch_shape, event_shape, splits, dim):
     shape = batch_shape + event_shape
     loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
@@ -72,24 +77,8 @@ def model():
     assert_close(actual_grads, expected_grads)
 
 
-@pytest.mark.parametrize(
-    "event_shape,splits,dim",
-    [
-        ((6,), [2, 1, 3], -1),
-        (
-            (
-                2,
-                5,
-            ),
-            [2, 3],
-            -1,
-        ),
-        ((4, 2), [1, 3], -2),
-        ((2, 3, 1), [1, 2], -2),
-    ],
-    ids=str,
-)
-@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
+@event_shape_splits_dim
+@batch_shape
 def test_init(batch_shape, event_shape, splits, dim):
     shape = batch_shape + event_shape
     loc = torch.empty(shape).uniform_(-1.0, 1.0)
@@ -100,3 +89,49 @@ def model():
         return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))
 
     check_init_reparam(model, SplitReparam(splits, dim))
+
+
+@event_shape_splits_dim
+@batch_shape
+def test_predictive(batch_shape, event_shape, splits, dim):
+    shape = batch_shape + event_shape
+    loc = torch.empty(shape).uniform_(-1.0, 1.0)
+    scale = torch.empty(shape).uniform_(0.5, 1.5)
+
+    def model():
+        with pyro.plate_stack("plates", batch_shape):
+            pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))
+
+    # Reparametrize model
+    rep = SplitReparam(splits, dim)
+    reparam_model = poutine.reparam(model, {"x": rep})
+
+    # Fit guide to reparametrized model
+    guide = pyro.infer.autoguide.guides.AutoMultivariateNormal(reparam_model)
+    optimizer = pyro.optim.Adam(dict(lr=0.01))
+    elbo = pyro.infer.JitTrace_ELBO(
+        num_particles=20, vectorize_particles=True, ignore_jit_warnings=True
+    )
+    svi = pyro.infer.SVI(reparam_model, guide, optimizer, elbo)
+    for count in range(1001):
+        loss = svi.step()
+        if count % 100 == 0:
+            print(f"iteration {count} loss = {loss}")
+
+    # Sample from model using the guide
+    num_samples = 100000
+    parallel = True
+    sites = ["x_split_{}".format(i) for i in range(len(splits))]
+    values = pyro.infer.Predictive(
+        reparam_model,
+        guide=guide,
+        num_samples=num_samples,
+        parallel=parallel,
+        return_sites=sites,
+    )()
+
+    # Verify sampling
+    mean = torch.cat([values[site].mean(0) for site in sites], dim=dim)
+    std = torch.cat([values[site].std(0) for site in sites], dim=dim)
+    assert_close(mean, loc, atol=0.1)
+    assert_close(std, scale, rtol=0.1)
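
For readers of the new test above, the pattern it exercises reduces to the short sketch below: reparametrize one site with SplitReparam, fit a guide to the reparametrized model, then draw samples of the per-chunk sites with Predictive and reassemble them. The concrete values here (a single 6-dimensional Normal split as [2, 1, 3], 100 SVI steps, 1000 samples) are illustrative stand-ins for the test's parametrized shapes and larger counts, not values taken from the patch.

import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Predictive, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.infer.reparam import SplitReparam


def model():
    # A single site with a 6-dimensional event; SplitReparam will cut it into
    # chunks of size 2, 1 and 3 along the last (and only) event dimension.
    pyro.sample("x", dist.Normal(torch.zeros(6), torch.ones(6)).to_event(1))


# Replace the site "x" with the sites "x_split_0", "x_split_1", "x_split_2".
rep = SplitReparam([2, 1, 3], dim=-1)
reparam_model = poutine.reparam(model, {"x": rep})

# Fit a guide to the reparametrized model; with no observations the fitted
# posterior should simply recover the prior.
guide = AutoMultivariateNormal(reparam_model)
svi = SVI(reparam_model, guide, pyro.optim.Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):  # illustrative iteration count
    svi.step()

# Draw samples of the split sites through the guide and reassemble the event.
sites = ["x_split_{}".format(i) for i in range(3)]
samples = Predictive(
    reparam_model, guide=guide, num_samples=1000, parallel=True, return_sites=sites
)()
x = torch.cat([samples[s] for s in sites], dim=-1)  # shape (1000, 6)

The torch.cat reassembly along the split dimension is the same step the test performs before comparing the sample mean and standard deviation against loc and scale.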