Skip to content

Commit

Permalink
Test support of transformed distributions, with stacking and concaten…
Browse files Browse the repository at this point in the history
…ation transforms, in SplitReparam.
  • Loading branch information
Ben Zickel committed Jul 30, 2024
1 parent 6000812 commit e46f4bb
Showing 1 changed file with 60 additions and 21 deletions.
81 changes: 60 additions & 21 deletions tests/infer/reparam/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from .util import check_init_reparam


@pytest.mark.parametrize(
event_shape_splits_dim = pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
Expand All @@ -31,7 +30,13 @@
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)


@event_shape_splits_dim
@batch_shape
def test_normal(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_()
Expand Down Expand Up @@ -72,24 +77,8 @@ def model():
assert_close(actual_grads, expected_grads)


@pytest.mark.parametrize(
"event_shape,splits,dim",
[
((6,), [2, 1, 3], -1),
(
(
2,
5,
),
[2, 3],
-1,
),
((4, 2), [1, 3], -2),
((2, 3, 1), [1, 2], -2),
],
ids=str,
)
@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str)
@event_shape_splits_dim
@batch_shape
def test_init(batch_shape, event_shape, splits, dim):
shape = batch_shape + event_shape
loc = torch.empty(shape).uniform_(-1.0, 1.0)
Expand All @@ -100,3 +89,53 @@ def model():
return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape)))

check_init_reparam(model, SplitReparam(splits, dim))


@batch_shape
def test_transformed_distribution(batch_shape):
num_samples = 10

transform = dist.transforms.StackTransform(
[
dist.transforms.OrderedTransform(),
dist.transforms.DiscreteCosineTransform(),
dist.transforms.HaarTransform(),
],
dim=-1,
)

num_transforms = len(transform.transforms)

def model():
scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1))
with pyro.plate_stack("plates", batch_shape):
x_dist = dist.TransformedDistribution(
dist.MultivariateNormal(
torch.zeros(num_samples, num_transforms), scale_tril=scale_tril
).to_event(1),
[transform],
)
return pyro.sample("x", x_dist)

assert model().shape == batch_shape + (num_samples, num_transforms)

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(model)
guide_sites = guide()

assert guide_sites["x"].shape == batch_shape + (num_samples, num_transforms)

for sections in [[1, 1, 1], [1, 2], [2, 1]]:
split_model = pyro.poutine.reparam(
model, config={"x": SplitReparam(sections, -1)}
)

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model)
guide_sites = guide()

for n, section in enumerate(sections):
assert guide_sites[f"x_split_{n}"].shape == batch_shape + (
num_samples,
section,
)

0 comments on commit e46f4bb

Please sign in to comment.