Skip to content

Commit

Permalink
Make Predictive work with the SplitReparam reparametrizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Jul 26, 2024
1 parent e3091e3 commit 4adfccb
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites

from pyro.infer.autoguide.initialization import InitMessenger, init_to_median

def _guess_max_plate_nesting(model, args, kwargs):
"""
Expand Down Expand Up @@ -86,12 +86,13 @@ def _predictive(
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
initailized_model = InitMessenger(init_to_median)(model)
max_plate_nesting = _guess_max_plate_nesting(initailized_model, model_args, model_kwargs)
vectorize = pyro.plate(
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
model_trace = prune_subsample_sites(
poutine.trace(model).get_trace(*model_args, **model_kwargs)
poutine.trace(initailized_model).get_trace(*model_args, **model_kwargs)
)
reshaped_samples = {}

Expand Down

0 comments on commit 4adfccb

Please sign in to comment.