Skip to content

Commit

Permalink
Support for transformed distributions, based on stacking or concatena…
Browse files Browse the repository at this point in the history
…tion transforms, in SplitReparam (#3390)
  • Loading branch information
BenZickel authored Aug 4, 2024
1 parent 6130da0 commit 5cebc44
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 2 deletions.
68 changes: 66 additions & 2 deletions pyro/infer/reparam/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,61 @@
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.distributions.torch_distribution import TorchDistributionMixin

from .reparam import Reparam


def same_support(fn: TorchDistributionMixin, *args):
"""
Returns support of the `fn` distribution. Used in :class:`SplitReparam` in
order to determine the support of the split value.
:param fn: distribution class
:returns: distribution support
"""
return fn.support


def real_support(fn: TorchDistributionMixin, *args):
"""
Returns real support with same event dimension as that of the `fn` distribution.
Used in :class:`SplitReparam` in order to determine the support of the split value.
:param fn: distribution class
:returns: distribution support
"""
return dist.constraints.independent(dist.constraints.real, fn.event_dim)


def default_support(fn: TorchDistributionMixin, slice, dim):
"""
Returns support of the `fn` distribution, corrected for split stacking and
concatenation transforms. Used in :class:`SplitReparam` in
order to determine the support of the split value.
:param fn: distribution class
:param slice: slice for which to return support
:param dim: dimension for which to return support
:returns: distribution support
"""
support = fn.support
# Unwrap support
reinterpreted_batch_ndims_vec = []
while isinstance(support, dist.constraints.independent):
reinterpreted_batch_ndims_vec.append(support.reinterpreted_batch_ndims)
support = support.base_constraint
# Slice concatenation and stacking transforms
if isinstance(support, dist.constraints.stack) and support.dim == dim:
support = dist.constraints.stack(support.cseq[slice], dim)
elif isinstance(support, dist.constraints.cat) and support.dim == dim:
support = dist.constraints.cat(support.cseq[slice], dim, support.lengths[slice])
# Wrap support
for reinterpreted_batch_ndims in reinterpreted_batch_ndims_vec[::-1]:
support = dist.constraints.independent(support, reinterpreted_batch_ndims)
return support


class SplitReparam(Reparam):
"""
Reparameterizer to split a random variable along a dimension, similar to
Expand All @@ -28,14 +79,21 @@ class SplitReparam(Reparam):
each chunk.
:type: list(int)
:param int dim: Dimension along which to split. Defaults to -1.
:param callable support_fn: Function which derives the split support
from the site's sampling function, split size, and split dimension.
Default is :func:`default_support` which correctly handles stacking
and concatenation transforms. Other options are :func:`same_support`
which returns the same support as that of the sampling function, and
:func:`real_support` which returns a real support.
"""

def __init__(self, sections, dim):
def __init__(self, sections, dim, support_fn=default_support):
assert isinstance(dim, int) and dim < 0
assert isinstance(sections, list)
assert all(isinstance(size, int) for size in sections)
self.event_dim = -dim
self.sections = sections
self.support_fn = support_fn

def apply(self, msg):
name = msg["name"]
Expand All @@ -53,14 +111,20 @@ def apply(self, msg):
dim = fn.event_dim - self.event_dim
left_shape = fn.event_shape[:dim]
right_shape = fn.event_shape[1 + dim :]
start = 0
for i, size in enumerate(self.sections):
event_shape = left_shape + (size,) + right_shape
value_split[i] = pyro.sample(
f"{name}_split_{i}",
dist.ImproperUniform(fn.support, fn.batch_shape, event_shape),
dist.ImproperUniform(
self.support_fn(fn, slice(start, start + size), -self.event_dim),
fn.batch_shape,
event_shape,
),
obs=value_split[i],
infer={"is_observed": is_observed},
)
start += size

# Combine parts into value.
if value is None:
Expand Down
50 changes: 50 additions & 0 deletions tests/infer/reparam/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,56 @@ def model():
check_init_reparam(model, SplitReparam(splits, dim))


@batch_shape
def test_transformed_distribution(batch_shape):
num_samples = 10

transform = dist.transforms.StackTransform(
[
dist.transforms.OrderedTransform(),
dist.transforms.DiscreteCosineTransform(),
dist.transforms.HaarTransform(),
],
dim=-1,
)

num_transforms = len(transform.transforms)

def model():
scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1))
with pyro.plate_stack("plates", batch_shape):
x_dist = dist.TransformedDistribution(
dist.MultivariateNormal(
torch.zeros(num_samples, num_transforms), scale_tril=scale_tril
).to_event(1),
[transform],
)
return pyro.sample("x", x_dist)

assert model().shape == batch_shape + (num_samples, num_transforms)

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(model)
guide_sites = guide()

assert guide_sites["x"].shape == batch_shape + (num_samples, num_transforms)

for sections in [[1, 1, 1], [1, 2], [2, 1]]:
split_model = pyro.poutine.reparam(
model, config={"x": SplitReparam(sections, -1)}
)

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model)
guide_sites = guide()

for n, section in enumerate(sections):
assert guide_sites[f"x_split_{n}"].shape == batch_shape + (
num_samples,
section,
)


@event_shape_splits_dim
@batch_shape
def test_predictive(batch_shape, event_shape, splits, dim):
Expand Down

0 comments on commit 5cebc44

Please sign in to comment.