diff --git a/tests/infer/reparam/test_split.py b/tests/infer/reparam/test_split.py index 6337069ea0..0167f5778d 100644 --- a/tests/infer/reparam/test_split.py +++ b/tests/infer/reparam/test_split.py @@ -13,8 +13,7 @@ from .util import check_init_reparam - -@pytest.mark.parametrize( +event_shape_splits_dim = pytest.mark.parametrize( "event_shape,splits,dim", [ ((6,), [2, 1, 3], -1), @@ -31,7 +30,13 @@ ], ids=str, ) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) + + +batch_shape = pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) + + +@event_shape_splits_dim +@batch_shape def test_normal(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape loc = torch.empty(shape).uniform_(-1.0, 1.0).requires_grad_() @@ -72,24 +77,8 @@ def model(): assert_close(actual_grads, expected_grads) -@pytest.mark.parametrize( - "event_shape,splits,dim", - [ - ((6,), [2, 1, 3], -1), - ( - ( - 2, - 5, - ), - [2, 3], - -1, - ), - ((4, 2), [1, 3], -2), - ((2, 3, 1), [1, 2], -2), - ], - ids=str, -) -@pytest.mark.parametrize("batch_shape", [(), (4,), (3, 2)], ids=str) +@event_shape_splits_dim +@batch_shape def test_init(batch_shape, event_shape, splits, dim): shape = batch_shape + event_shape loc = torch.empty(shape).uniform_(-1.0, 1.0) @@ -100,3 +89,53 @@ def model(): return pyro.sample("x", dist.Normal(loc, scale).to_event(len(event_shape))) check_init_reparam(model, SplitReparam(splits, dim)) + + +@batch_shape +def test_transformed_distribution(batch_shape): + num_samples = 10 + + transform = dist.transforms.StackTransform( + [ + dist.transforms.OrderedTransform(), + dist.transforms.DiscreteCosineTransform(), + dist.transforms.HaarTransform(), + ], + dim=-1, + ) + + num_transforms = len(transform.transforms) + + def model(): + scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1)) + with pyro.plate_stack("plates", batch_shape): + x_dist = dist.TransformedDistribution( + dist.MultivariateNormal( + torch.zeros(num_samples, num_transforms), scale_tril=scale_tril + ).to_event(1), + [transform], + ) + return pyro.sample("x", x_dist) + + assert model().shape == batch_shape + (num_samples, num_transforms) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(model) + guide_sites = guide() + + assert guide_sites["x"].shape == batch_shape + (num_samples, num_transforms) + + for sections in [[1, 1, 1], [1, 2], [2, 1]]: + split_model = pyro.poutine.reparam( + model, config={"x": SplitReparam(sections, -1)} + ) + + pyro.clear_param_store() + guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model) + guide_sites = guide() + + for n, section in enumerate(sections): + assert guide_sites[f"x_split_{n}"].shape == batch_shape + ( + num_samples, + section, + )