Error when using `infer_discrete` for a model with a local variable that depends on a discrete variable #2860
Open
Description
Issue Description
I'm trying to use `infer_discrete` for a model (code snippet below) that has a local latent variable (`locs`) which depends on a discrete variable (`assignment`). This leads to an error because `site["log_prob"].shape` and `dim_to_symbol` don't match:
File "/home/ordabayev/repos/pyro/pyro/poutine/trace_struct.py", line 376, in pack_tensors
packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
File "/home/ordabayev/repos/pyro/pyro/ops/packed.py", line 29, in pack
raise ValueError('\n '.join([
ValueError: Error while packing tensors at site 'locs':
Invalid tensor shape.
Allowed dims: -1
Actual shape: (2, 5)
Using pdb:
-> packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
(Pdb) p site["log_prob"].shape
torch.Size([2, 5])
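(Pdb) p dim_to_symbol
{-1: 'a'}
So the `log_prob` at `locs` has an extra leading dimension of size 2, presumably the enumeration dim -2 of `assignment` (which has K = 2 values), while `dim_to_symbol` only covers the plate dim -1.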
Environment
- Pyro dev branch
Code Snippet
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete
# Toy observations: two groups of points, one near 0 and one near 10.
data = torch.tensor([0.0, 1.0, 10.0, 11.0, 12.0])
K = 2  # Fixed number of mixture components (the model hard-codes 2 centres).
@config_enumerate
def model(data):
    """Mixture model whose per-datum ``locs`` depends on a discrete assignment.

    Structure:

    * globals: ``weights ~ Dirichlet(0.5 * ones(K))`` and
      ``scale ~ LogNormal(0, 2)``;
    * fixed cluster centres ``clusters = [0., 10.]``;
    * inside the ``'data'`` plate (size ``len(data)``):
      ``assignment ~ Categorical(weights)`` (discrete),
      ``locs ~ Normal(clusters[assignment], 2)`` (continuous, depends on
      ``assignment``), and ``obs ~ Normal(locs, scale)`` observed at ``data``.

    ``@config_enumerate`` configures enumeration for sites that support it;
    here that is the discrete ``assignment`` site.

    :param data: observations; presumably a 1-D float tensor, since
        ``len(data)`` sets the plate size. TODO confirm for other callers.
    """
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    # NOTE(review): exactly two hard-coded centres, so indexing them with
    # ``assignment`` (values 0..K-1) assumes K == 2.
    clusters = torch.tensor([0., 10.])
    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        # Continuous local whose location is selected by the discrete site;
        # this is the site named in the reported packing error.
        locs = pyro.sample('locs', dist.Normal(clusters[assignment], 2.))
        pyro.sample('obs', dist.Normal(locs, scale), obs=data)
def guide(data):
    """Guide for ``model`` that samples only the continuous sites.

    Draws the globals ``weights`` and ``scale`` from the same fixed
    distributions the model uses as priors. Inside the ``'data'`` plate it
    draws the local ``locs`` from ``Normal(10., 2.)``. There are no learnable
    parameters. The discrete ``assignment`` site is not sampled here,
    presumably because it is enumerated in the model.

    :param data: observations; only ``len(data)`` is used (the plate size).
    """
    # Global variables.
    pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('data', len(data)):
        # Local variables.
        pyro.sample('locs', dist.Normal(10., 2.))
# Optimizer and loss are built for completeness; nothing below uses them.
optim = pyro.optim.Adam(dict(lr=0.1, betas=[0.8, 0.99]))
elbo = TraceEnum_ELBO(max_plate_nesting=1)

# Run the guide once and record every sample site it visits: the globals
# ('weights', 'scale') AND the local 'locs'.
guide_trace = poutine.trace(guide).get_trace(data)

# Replay the recorded guide values inside the model. Because the guide trace
# contains 'locs', that local site is replayed as well, not just the globals.
trained_model = poutine.replay(model, trace=guide_trace)

# Wrap the model so the discrete 'assignment' site is inferred. Enumeration
# starts at dim -2 so it does not collide with the 'data' plate at dim -1
# (max_plate_nesting=1).
inferred_model = infer_discrete(
    trained_model,
    temperature=1,
    first_available_dim=-2,
)

# NOTE(review): the packing ValueError reported in this issue presumably
# surfaces here, when the wrapped model is first executed.
trace = poutine.trace(inferred_model).get_trace(data)
print(trace.nodes["assignment"]["value"])