Skip to content

Commit

Permalink
Added test of Predictive with the SplitReparam reparametrizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Jul 26, 2024
1 parent 4adfccb commit d3d2936
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 24 deletions.
9 changes: 6 additions & 3 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@

import pyro
import pyro.poutine as poutine
from pyro.infer.autoguide.initialization import InitMessenger, init_to_sample
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites
from pyro.infer.autoguide.initialization import InitMessenger, init_to_median


def _guess_max_plate_nesting(model, args, kwargs):
"""
Expand Down Expand Up @@ -86,8 +87,10 @@ def _predictive(
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
initailized_model = InitMessenger(init_to_median)(model)
max_plate_nesting = _guess_max_plate_nesting(initailized_model, model_args, model_kwargs)
initailized_model = InitMessenger(init_to_sample)(model)
max_plate_nesting = _guess_max_plate_nesting(
initailized_model, model_args, model_kwargs
)
vectorize = pyro.plate(
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
Expand Down
77 changes: 56 additions & 21 deletions tests/infer/reparam/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from .util import check_init_reparam


@pytest.mark.parametrize(
event_shape_splits_dim = pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
Expand All @@ -31,7 +30,13 @@
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


@event_shape_splits_dim
@batch_shape
def test_normal(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
Expand Down Expand Up @@ -72,24 +77,8 @@ def model():
assert_close(actual_grads, expected_grads)


@pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
(
(
2,
5,
),
[2, 3],
-1,
),
((4, 2), [1, 3], -2),
((2, 3, 1), [1, 2], -2),
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
@event_shape_splits_dim
@batch_shape
def test_init(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
Expand All @@ -100,3 +89,49 @@ def model():
return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

check_init_reparam(model, SplitReparam(splits, dim))


@event_shape_splits_dim
@batch_shape
def test_predictive(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
scale = torch.empty(shape).uniform_(0.5, 1.5)

def model():
with pyro.plate_stack("plates", batch_shape):
pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

# Reparametrize model
rep = SplitReparam(splits, dim)
reparam_model = poutine.reparam(model, {"x": rep})

# Fit guide to reparametrized model
guide = pyro.infer.autoguide.guides.AutoMultivariateNormal(reparam_model)
optimizer = pyro.optim.Adam(dict(lr=0.01))
loss = pyro.infer.JitTrace_ELBO(
num_particles=20, vectorize_particles=True, ignore_jit_warnings=True
)
svi = pyro.infer.SVI(reparam_model, guide, optimizer, loss)
for count in range(1001):
loss = svi.step()
if count % 100 == 0:
print(f"iteration {count} loss = {loss}")

# Sample from model using the guide
num_samples = 100000
parallel = True
sites = ["x_split_{}".format(i) for i in range(len(splits))]
values = pyro.infer.Predictive(
reparam_model,
guide=guide,
num_samples=num_samples,
parallel=parallel,
return_sites=sites,
)()

# Verify sampling
mean = torch.cat([values[site].mean(0) for site in sites], dim=dim)
std = torch.cat([values[site].std(0) for site in sites], dim=dim)
assert_close(mean, loc, atol=0.1)
assert_close(std, scale, rtol=0.1)

0 comments on commit d3d2936

Please sign in to comment.