-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Require backends #95
Require backends #95
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR, and IMO it's much easier to review properly in this way! I added some preliminary comments (I haven't checked all details yet).
move functions from tracker.jl to common.jl and call them for both Tracker/ReverseDiff backends Also gets rid of the ZygoteRules dependency for the ReverseDiff backend
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
…sAD.jl into require-backends
test/PosDef.jl
Outdated
@@ -0,0 +1,28 @@ | |||
module PosDef |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it necessary to put it in a separate module? I would be surprised at least if the use of Requires and separate modules are needed to define the adjoint for Tracker.
I would have assumed that something like
if AD == "All" || AD == "Tracker"
include("tracker_utils.jl")
end
in test/ad/distributions.jl would be sufficient.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that the TrackedMatrix
overload is compiled even if Tracker
is not loaded...
I tried it with Requires (and with include
) in test/ad/distributions.jl
which didn't work, but defining is_posdef
in test/runtests.jl
and requiring the Tracker stuff there does the job. So there is no more separate module, but there might be an obvious solution I am not seeing?
Sorry devmotion, read your comment to late... This reverts commit 09d6cd4.
There is a file what to do with it? |
@nmheim These codes should probably live in the package |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had a few suggestions for the ChainRules overloads
@@ -3,14 +3,13 @@ uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" | |||
version = "0.6.2" | |||
|
|||
[deps] | |||
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" | |||
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there anything in ChainRules specifically that you need, or could you use ChainRulesCore, which is much lighter? There's also ChainRulesTestUtils.rrule_test
, which makes it easy to test the rules directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
First, thanks a lot for all the suggestions! :)
ChainRules is needed for now, because we are using chol_blocked_rev
here:
DistributionsAD.jl/src/common.jl
Line 11 in 5ec5201
∂X = @thunk(ChainRules.chol_blocked_rev(f̄, factors, 25, true)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're welcome! That makes sense. I added an alternate suggestion below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SInce you define turing_chol
as
function turing_chol(A::AbstractMatrix, check)
chol = cholesky(A, check=check)
(chol.factors, chol.info)
end
couldn't you define its rrule
as (something like)
function ChainRulesCore.rrule(::typeof(turing_chol), A::AbstractMatrix, check)
chol, inner_pullback = rrule(cholesky, A, check=check)
Ω = (chol.factors, chol.info)
function turing_chol_pullback(ΔΩ)
(Δfactors, Δinfo) = ΔΩ
Δchol = Composite{typeof(chol)}(factors = Δfactors, info = Δinfo)
(_, ∂A, ∂check) = inner_pullback(Δchol)
return (NO_FIELDS, A, ∂check)
end
return Ω, turing_chol_pullback
end
? This would be a bit safer, as it doesn't use ChainRules internal functions.
The check
argument isn't currently supported by ChainRules. I'll fix.
return upper(C.factors) | ||
else | ||
return copy(lower(C.factors)') | ||
function ChainRules.rrule(::typeof(turing_chol), A::AbstractMatrix, check) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently chol_blocked_rev
is only tested for real matrices. Probably the right thing to do is relax that constraint on ChainRules and test (I'll open a PR).
I will push the suggestions by @sethaxen once ChainRules can deal with keyword arguments. Then this should be ready. |
Ok, it will take some more time until the keyword stuff for the |
This PR tries to separate the different AD backends into files that can be loaded optionally via
Requires.jl
.TODO
test/others.jl
zygote.jl
back to original functions definitions