Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] compute_expectation wrt plate and markov variables #413

Closed
wants to merge 22 commits into from

Conversation

ordabayevy
Copy link
Member

Addresses pyro-ppl/pyro#2724.

This function computes the expected value of cost terms wrt to log measures.

@ordabayevy ordabayevy changed the title compute_expectation [WIP] compute_expectation wrt plate and markov variables Dec 30, 2020
Copy link
Member Author

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680 can you review this?
Some notes:

  • TraceMarkovEnum_ELBO that uses compute_expectation passes all of the tests locally for updated test_vectorized_markov::test_guide_enumerated_elbo from Make replay work correctly with vectorized_markov pyro#2726
  • This PR can be simplified if funsor.adjoint can be used instead of forward_backward_terms function written here.

@eb8680
Copy link
Member

eb8680 commented Jan 3, 2021

@ordabayevy I'm impressed you got this working, and hopefully your model should be unblocked at least locally via this PR and pyro-ppl/pyro#2725 . Please let me know if that's not the case.

However, I am reluctant to merge this in its current form, especially without any tests - the algorithm implemented here is quite complex, and there is a substantial amount of code duplication from modified_partial_sum_product that I would prefer to reduce or avoid if possible given the changes we expect to make to the internals of modified_partial_sum_product in the future (#293). It's also not yet clear to me how we will need to change the algorithm or interface to facilitate sharing of computation across multiple integrands, as is done (somewhat imperfectly) in Pyro's pyro.infer.util.Dice.compute_expectation.

Can we talk sometime this week about how it might be simplified? I have some intuitions but not a fully fleshed out plan. It would probably be useful for both of us to prepare by reading this paper, which discusses general dynamic programming algorithms for computing expectations.

@ordabayevy
Copy link
Member Author

Please let me know if that's not the case.

Okay, I'll let you know if there are any problems switching to Funsor backend and trying out this version of TraceMarkovEnum_ELBO.

Can we talk sometime this week about how it might be simplified?

Thursday and Friday will work for me.

@ordabayevy
Copy link
Member Author

@eb8680 I've created an HMM version of my model where both m and theta are guide enumerated. I've tested it on simulated and real data and it seems to work! Although a bit slow which might be because compute_expectation works only in eager interpretation. As we discussed last week lazy implementation is blocked by the _scatter function which doesn't have a lazy equivalent right now.

@eb8680
Copy link
Member

eb8680 commented Jan 13, 2021

@ordabayevy that's great! I'll look into making _scatter a first-class funsor. One other thing you can try to speed it up is using the PyTorch JIT - it won't make the underlying tensor operations any faster but it should at least compile away the Funsor pattern-matching overhead.

That should be as simple as creating a JitTraceMarkovEnum_ELBO class similar to JitTraceEnum_ELBO and dropping it into your Pyro code:

from pyro.contrib.funsor.infer.elbo import Jit_ELBO
...
class JitTraceMarkovEnum_ELBO(Jit_ELBO, TraceMarkovEnum_ELBO):
    pass

For a Jit*_ELBO usage example see https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/funsor/hmm.py

@ordabayevy ordabayevy marked this pull request as draft January 16, 2021 02:34
@ordabayevy
Copy link
Member Author

Closing as obsolete.

@ordabayevy ordabayevy closed this Mar 16, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants