Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Effect handler that conditions a model on sample sites having the same value #3395

Merged
merged 4 commits into from
Sep 20, 2024

Conversation

BenZickel
Copy link
Contributor

Problem Description

It would be helpful to have an effect handler that can condition a model on sample sites having the same value.

Suggested Solution

Use the EqualizeMessenger effect handler with a newly added option keep_dist, that when set to True keeps the original distribution functions of the sample sites, as opposed to the default behavior of converting the second and subsequent sites to be deterministic.

Usage Example

Consider the model

def model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.sample('y', pyro.distributions.Normal(5, 3))

The model can be conditioned on ‘x’ and ‘y’ having the same value by

conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)

which is equivalent to

def conditioned_model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.sample('y', pyro.distributions.Normal(5, 3), obs=x)

as opposed to the default behavior of EqualizeMessenger with keep_dist equal to False such that

equalized_model = pyro.poutine.equalize(model, ['x', 'y'])

which is equivalent to

def equalized_model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.deterministic('y', x)

Note that the conditioned model defined above calculates the correct unnormalized log-probablity density, but in order to correctly sample from it one must use SVI or MCMC techniques.

Testing

I've added a test for the conditioned model case, with two normally distributed random variables, which allows for analytic calculation of the expected resulting normal distribution.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @BenZickel!

@fehiepsi fehiepsi merged commit b3c7851 into pyro-ppl:dev Sep 20, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants