[Merged by Bors] - Fix Zygote issue with dot_observe
#236
LGTM! Not much to say here I think 🤷♂️ Nice work, again :)
bors r+
This PR fixes TuringLang/Turing.jl#1595. It is an alternative to #235 that does not require us to rewrite the primal less efficiently, which would affect regular execution and other AD backends.

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
function dot_observe_fallback(spl, dists, value, vi)
    increment_num_produce!(vi)
    return sum(map(Distributions.loglikelihood, dists, value))
end
I was just making some changes to #309 and noticed the adjoint introduced here. Isn't this wrong? If dists = [Normal()] and value = [1.0, 2.0], then we'd end up with sum([loglikelihood(Normal(), 1.0)]), right? 😕
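To make the concern concrete, here is a minimal sketch in plain Julia. `stdnormal_logpdf` is a hypothetical stand-in for `Distributions.loglikelihood(Normal(), x)` so the snippet runs without extra packages; it is not part of DynamicPPL. The sketch relies on Base `map` stopping at the shorter of two vectors, whereas broadcasting expands a length-1 vector across the longer one:

```julia
# Hypothetical stand-in for Distributions.loglikelihood(Normal(), x):
# the log-density of a standard normal. The first argument is ignored.
stdnormal_logpdf(_dist, x) = -0.5 * x^2 - 0.5 * log(2π)

dists = [:stdnormal]   # one placeholder "distribution"
value = [1.0, 2.0]     # two observations

# `map` over two vectors stops when the shorter one is exhausted,
# so only value[1] contributes to the sum.
per_element = map(stdnormal_logpdf, dists, value)
truncated = sum(per_element)

# Broadcasting expands the length-1 `dists` across `value`,
# so both observations contribute.
broadcasted = sum(stdnormal_logpdf.(dists, value))

println("map:       ", truncated)
println("broadcast: ", broadcasted)
```

If the primal is meant to follow broadcasting semantics for mismatched lengths, the `map`-based fallback would silently drop the second observation in this case.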
I agree, but at least back then this case was not supported by the primal definition, so it did not seem necessary to consider it here: https://github.com/TuringLang/DynamicPPL.jl/pull/235/files
It seems the primal definition was removed/changed though. Maybe it's not even needed anymore?