
[FR] Relax _check_model_guide_enumeration_constraint in TraceEnum_ELBO #2809

Open
ordabayevy opened this issue Apr 19, 2021 · 1 comment
@ordabayevy
Member

ordabayevy commented Apr 19, 2021

Issue Description

According to a conversation with @fritzo, _check_model_guide_enumeration_constraint in TraceEnum_ELBO is unnecessarily strict. This feature request (FR) asks to relax the constraint, or at least to provide an option to disable it.

(My particular model has two independent branches, and model enumeration in one of them triggers a false-positive ValueError from _check_model_guide_enumeration_constraint.)

Simple model & guide example:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import TraceEnum_ELBO

data1 = torch.tensor([3.2, 1.])
data2 = torch.tensor([1.3, 2.1, 3.4])

def model(data1, data2):
    locs = torch.tensor([1., 2.])
    # Branch 1: b1 is enumerated in the model; c1 depends on b1.
    with pyro.plate("plate1", 2):
        b1 = pyro.sample("b1", dist.Categorical(torch.ones(3)), infer={"enumerate": "parallel"})
        c1 = pyro.sample("c1", dist.Categorical(torch.ones(3, 2)[b1]))
        pyro.sample("obs1", dist.Normal(locs[c1], 1), obs=data1)

    # Branch 2: independent of branch 1; c2 is enumerated only in the guide.
    with pyro.plate("plate2", 3):
        c2 = pyro.sample("c2", dist.Categorical(torch.ones(2)))
        pyro.sample("obs2", dist.Normal(locs[c2], 1), obs=data2)

def guide(data1, data2):
    c1_probs = pyro.param("c1_probs", torch.ones(2, 2), constraint=constraints.simplex)
    with pyro.plate("plate1", 2):
        pyro.sample("c1", dist.Categorical(c1_probs), infer={"enumerate": "parallel"})

    c2_probs = pyro.param("c2_probs", torch.ones(3, 2), constraint=constraints.simplex)
    with pyro.plate("plate2", 3):
        pyro.sample("c2", dist.Categorical(c2_probs), infer={"enumerate": "parallel"})

# Evaluating the ELBO should raise the ValueError shown below.
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data1, data2)

Raised error:

ValueError: Expected model enumeration to be no more global than guide enumeration, but found model enumeration sites upstream of guide site 'c2' in plate('plate2'). Try converting some model enumeration sites to guide enumeration sites.
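To make the false positive concrete, here is a hypothetical toy model of the check in plain Python, not written against Pyro. It assumes, based only on the wording of the error message, that the strict check intersects the plates of all model enumeration sites and flags any guide enumeration site that sits in a plate outside that intersection. The relaxed variant is equally hypothetical: it only compares a guide site against model enumeration sites in the same dependency branch. The helper names (strict_violations, branch, relaxed_violations) are illustrative, and the site names, plates, and dependencies are transcribed from the example above.

```python
# Hypothetical toy of the model/guide enumeration check; NOT Pyro's source.
# Site names, plates, and dependencies are transcribed from the example above.
from typing import Dict, FrozenSet, List, Set, Tuple

# Model enumeration sites and the plates each one lives in (only b1 here).
MODEL_ENUM: Dict[str, FrozenSet[str]] = {"b1": frozenset({"plate1"})}

# Guide enumeration sites and their plates.
GUIDE_ENUM: Dict[str, FrozenSet[str]] = {
    "c1": frozenset({"plate1"}),
    "c2": frozenset({"plate2"}),
}

# Model dependencies as child -> parents.
PARENTS: Dict[str, Set[str]] = {
    "c1": {"b1"},
    "obs1": {"c1"},
    "c2": set(),
    "obs2": {"c2"},
}

Violation = Tuple[str, str]  # (guide site, offending plate)


def strict_violations(model_enum: Dict[str, FrozenSet[str]],
                      guide_enum: Dict[str, FrozenSet[str]]) -> List[Violation]:
    """Flag guide sites in any plate not shared by ALL model enumeration sites."""
    shared = frozenset.intersection(*model_enum.values())
    return [(name, plate)
            for name, plates in sorted(guide_enum.items())
            for plate in sorted(plates)
            if plate not in shared]


def branch(site: str, parents: Dict[str, Set[str]]) -> Set[str]:
    """Return every site connected to `site`, ignoring edge direction."""
    neighbours: Dict[str, Set[str]] = {}
    for child, ps in parents.items():
        for p in ps:
            neighbours.setdefault(child, set()).add(p)
            neighbours.setdefault(p, set()).add(child)
    seen, stack = {site}, [site]
    while stack:
        for n in neighbours.get(stack.pop(), set()):
            if n not in seen:
                seen.add(n)
                stack.append(n)
    return seen


def relaxed_violations(model_enum: Dict[str, FrozenSet[str]],
                       guide_enum: Dict[str, FrozenSet[str]],
                       parents: Dict[str, Set[str]]) -> List[Violation]:
    """Like strict_violations, but only against model sites in the same branch."""
    out: List[Violation] = []
    for name, plates in sorted(guide_enum.items()):
        connected = branch(name, parents)
        related = [model_enum[m] for m in model_enum if m in connected]
        if not related:
            continue  # no model enumeration site shares this branch
        shared = frozenset.intersection(*related)
        out.extend((name, plate) for plate in sorted(plates) if plate not in shared)
    return out


print(strict_violations(MODEL_ENUM, GUIDE_ENUM))            # [('c2', 'plate2')]
print(relaxed_violations(MODEL_ENUM, GUIDE_ENUM, PARENTS))  # []
```

On the example, the strict toy flags ('c2', 'plate2'), matching the site and plate named in the reported error, while the branch-aware toy flags nothing because b1 and c2 never share a branch. Whether restricting the check to connected branches is sound in general is not established by this sketch.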
@fritzo
Member

fritzo commented Apr 20, 2021

Thanks for the clear reproducible example @ordabayevy!
