-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] compute_expectation wrt plate and markov variables #413
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eb8680 can you review this?
Some notes:
TraceMarkovEnum_ELBO
that usescompute_expectation
passes all of the tests locally for updatedtest_vectorized_markov::test_guide_enumerated_elbo
from Make replay work correctly with vectorized_markov pyro#2726- This PR can be simplified if
funsor.adjoint
can be used instead offorward_backward_terms
function written here.
@ordabayevy I'm impressed you got this working, and hopefully your model should be unblocked at least locally via this PR and pyro-ppl/pyro#2725 . Please let me know if that's not the case. However, I am reluctant to merge this in its current form, especially without any tests - the algorithm implemented here is quite complex, and there is a substantial amount of code duplication from Can we talk sometime this week about how it might be simplified? I have some intuitions but not a fully fleshed out plan. It would probably be useful for both of us to prepare by reading this paper, which discusses general dynamic programming algorithms for computing expectations. |
Okay, I'll let you know if there are any problems switching to Funsor backend and trying out this version of
Thursday and Friday will work for me. |
…ential-integral merge master
@eb8680 I've created an HMM version of my model where both |
@ordabayevy that's great! I'll look into making That should be as simple as creating a from pyro.contrib.funsor.infer.elbo import Jit_ELBO
...
class JitTraceMarkovEnum_ELBO(Jit_ELBO, TraceMarkovEnum_ELBO):
pass For a |
Closing as obsolete. |
Addresses pyro-ppl/pyro#2724.
This function computes the expected value of cost terms wrt to log measures.