Skip to content

Error when using infer_discrete for a model with local variable that depends on discrete variable #2860

Open
@ordabayevy

Description

Issue Description

I'm trying to use infer_discrete for a model (code snippet below) which has local latent variable (locs) that depends on discrete variable (assignment). This leads to an error where site["log_prob"].shape and dim_to_symbol don't match:

  File "/home/ordabayev/repos/pyro/pyro/poutine/trace_struct.py", line 376, in pack_tensors
    packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
  File "/home/ordabayev/repos/pyro/pyro/ops/packed.py", line 29, in pack
    raise ValueError('\n  '.join([
ValueError: Error while packing tensors at site 'locs':
  Invalid tensor shape.
  Allowed dims: -1
  Actual shape: (2, 5)

Using pdb:

-> packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
(Pdb) p site["log_prob"].shape
torch.Size([2, 5])
(Pdb) p dim_to_symbol
{-1: 'a'}

Environment

  • Pyro dev branch

Code Snippet

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

data = torch.tensor([0., 1., 10., 11., 12.])

K = 2  # Fixed number of components.


@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    clusters = torch.tensor([0., 10.])

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        locs = pyro.sample('locs', dist.Normal(clusters[assignment], 2.))
        pyro.sample('obs', dist.Normal(locs, scale), obs=data)


def guide(data):
    # Global variables.
    pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    pyro.sample('scale', dist.LogNormal(0., 2.))

    with pyro.plate('data', len(data)):
        # Local variables.
        pyro.sample('locs', dist.Normal(10., 2.))


optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

guide_trace = poutine.trace(guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)  # replay the globals


inferred_model = infer_discrete(trained_model, temperature=1,
                                first_available_dim=-2)  # avoid conflict with data plate
trace = poutine.trace(inferred_model).get_trace(data)
print(trace.nodes["assignment"]["value"])

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions