Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,7 @@ function dot_observe(
increment_num_produce!(vi)
@debug "dists = $dists"
@debug "value = $value"
return sum(zip(dists, value)) do (d, v)
Distributions.loglikelihood(d, v)
end
return sum(Distributions.loglikelihood.(dists, value))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if Zygote is happy with

Suggested change
return sum(Distributions.loglikelihood.(dists, value))
return mapreduce(Distributions.loglikelihood, +, dists, value)

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be even a bit simpler than the original expression 🙂

Copy link
Member Author

@Red-Portal Red-Portal Apr 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@devmotion Unfortunately, no. Zygote doesn't seem to like that. I get

ERROR: Can't differentiate loopinfo expression

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems to be the sum bug that I mentioned in the original issue: FluxML/Zygote.jl#897 Just yesterday two PRs were opened that fix this problem, can you see if it works with these fixes? The relevant PR should be FluxML/Zygote.jl#956 but not completely sure, maybe it's the other one 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, it seems like both do not solve the problem probably because mapreduce has it's own implementation that does not rely on sum (pardon me if I'm wrong on this; see reduce.jl).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to summarize, Zygote is a bit annoying here... Maybe we should just rewrite the primal for Zygote in https://github.com/TuringLang/DynamicPPL.jl/blob/master/src/compat/ad.jl (we can use ZygoteRules and don't have to depend on Zygote). This would not impact performance and allocations in regular executions and with other AD backends.

Copy link
Member Author

@Red-Portal Red-Portal Apr 27, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@devmotion This is quite down the rabbit hole. I'm really sorry to say this, but I currently don't have time to look deeper into this issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries, I'll try to add a fix later today.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened a PR with the changes that I had in mind: #236

end
function dot_observe(
spl::Sampler,
Expand Down