Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Predictive work with the SplitReparam reparameterizer [bugfix] #3388

Merged
merged 7 commits into from
Aug 4, 2024
Next Next commit
Make Predictive work with the SplitReparam reparametrizer.
  • Loading branch information
Ben Zickel committed Jul 26, 2024
commit 4adfccbcd5c5cd8c0525f72268538c01f50667d3
7 changes: 4 additions & 3 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites

from pyro.infer.autoguide.initialization import InitMessenger, init_to_median

def _guess_max_plate_nesting(model, args, kwargs):
"""
Expand Down Expand Up @@ -86,12 +86,13 @@ def _predictive(
mask=True,
):
model = torch.no_grad()(poutine.mask(model, mask=False) if mask else model)
max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
initailized_model = InitMessenger(init_to_median)(model)
max_plate_nesting = _guess_max_plate_nesting(initailized_model, model_args, model_kwargs)
vectorize = pyro.plate(
_predictive_vectorize_plate_name, num_samples, dim=-max_plate_nesting - 1
)
model_trace = prune_subsample_sites(
poutine.trace(model).get_trace(*model_args, **model_kwargs)
poutine.trace(initailized_model).get_trace(*model_args, **model_kwargs)
)
reshaped_samples = {}

Expand Down