Require backends #95

Merged
nmheim merged 29 commits into master from require-backends
Jul 26, 2020

nmheim
Collaborator

@nmheim nmheim commented Jun 30, 2020

This PR tries to separate the different AD backends into files that can be loaded optionally via Requires.jl.

TODO

  • univariate distributions
  • multivariate distributions
  • matrix variate distributions
  • product distributions
  • test each backend separately
  • ForwardDiff tests for test/others.jl
  • move adjoints in zygote.jl back to original functions definitions

@nmheim nmheim requested a review from devmotion June 30, 2020 15:35
Member

@devmotion devmotion left a comment

Thanks for the PR, and IMO it's much easier to review properly in this way! I added some preliminary comments (I haven't checked all details yet).

@nmheim nmheim requested a review from devmotion July 2, 2020 09:47
test/PosDef.jl Outdated
@@ -0,0 +1,28 @@
module PosDef
Member

Is it necessary to put it in a separate module? I would be surprised if both Requires and a separate module were needed to define the adjoint for Tracker.

I would have assumed that something like

if AD == "All" || AD == "Tracker"
    include("tracker_utils.jl")
end

in test/ad/distributions.jl would be sufficient.
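
A minimal sketch of how such a backend switch could be driven, assuming the selection comes from an environment variable called AD (the variable name, the `backend_enabled` and `files_to_include` helpers, and the `HELPER_FILES` table are all hypothetical and not taken from the PR). It only models the selection logic and does not address when the Tracker overloads get compiled:

```julia
# Hypothetical: read the backend selection from an environment variable "AD".
# When it is unset, every backend is considered enabled.
const AD = get(ENV, "AD", "All")

# True when `name` should be tested under the given selection.
backend_enabled(name::AbstractString, selected::AbstractString = AD) =
    selected == "All" || selected == name

# Hypothetical table of backend-specific helper files.
# Only the Tracker file name comes from the comment above.
const HELPER_FILES = Dict("Tracker" => "tracker_utils.jl")

# Collect the helper files that belong to the enabled backends.
files_to_include(selected::AbstractString = AD) =
    [file for (name, file) in HELPER_FILES if backend_enabled(name, selected)]
```

In a test file, one would then loop over `files_to_include()` and call `include` on each entry, so that helper code for a disabled backend is never evaluated.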

Collaborator Author

@nmheim nmheim Jul 2, 2020

The problem is that the TrackedMatrix overload is compiled even if Tracker is not loaded...
I tried it with Requires (and with include) in test/ad/distributions.jl which didn't work, but defining is_posdef in test/runtests.jl and requiring the Tracker stuff there does the job. So there is no more separate module, but there might be an obvious solution I am not seeing?

@nmheim
Collaborator Author

nmheim commented Jul 2, 2020

There is a file src/for.jl that is not loaded in src/DistributionsAD.jl, that was introduced in this PR: #13

what to do with it?

@devmotion devmotion mentioned this pull request Jul 7, 2020
@yebai
Member

yebai commented Jul 8, 2020

There is a file src/for.jl that is not loaded in src/DistributionsAD.jl, that was introduced in this PR: #13
what to do with it?

@nmheim This code should probably live in the package Probability.jl. But for now, perhaps consider moving it into a dedicated (temporary) subfolder, say combinators/?

Member

@sethaxen sethaxen left a comment

I had a few suggestions for the ChainRules overloads

@@ -3,14 +3,13 @@ uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Member

Is there anything in ChainRules specifically that you need, or could you use ChainRulesCore, which is much lighter? There's also ChainRulesTestUtils.rrule_test, which makes it easy to test the rules directly.

Collaborator Author

First, thanks a lot for all the suggestions! :)

ChainRules is needed for now, because we are using chol_blocked_rev here:

∂X = @thunk(ChainRules.chol_blocked_rev(f̄, factors, 25, true))

Member

You're welcome! That makes sense. I added an alternate suggestion below.

Member

@sethaxen sethaxen left a comment

Since you define turing_chol as

function turing_chol(A::AbstractMatrix, check)
    chol = cholesky(A, check=check)
    (chol.factors, chol.info)
end

couldn't you define its rrule as (something like)

function ChainRulesCore.rrule(::typeof(turing_chol), A::AbstractMatrix, check)
    chol, inner_pullback = rrule(cholesky, A, check=check)
    Ω = (chol.factors, chol.info)
    function turing_chol_pullback(ΔΩ)
        (Δfactors, Δinfo) = ΔΩ
        Δchol = Composite{typeof(chol)}(factors = Δfactors, info = Δinfo)
        (_, ∂A, ∂check) = inner_pullback(Δchol)
        return (NO_FIELDS, ∂A, ∂check)
    end
    return Ω, turing_chol_pullback
end

? This would be a bit safer, as it doesn't use ChainRules internal functions.

The check argument isn't currently supported by ChainRules. I'll fix.
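
The delegation pattern suggested here can be illustrated without any AD package. The following is a toy sketch in plain Julia, not the ChainRulesCore API: `square_rule` stands in for an existing rule such as the one for cholesky, and `square_with_flag_rule` stands in for a rule for a wrapper like turing_chol that returns a value plus a non-differentiable status flag. All names are hypothetical.

```julia
# Toy stand-in for an inner function that already has a reverse rule.
# Returns the primal value and a pullback mapping an output cotangent
# to an input cotangent.
function square_rule(x::Real)
    y = x^2
    square_pullback(dy) = 2 * x * dy   # d(x^2)/dx = 2x
    return y, square_pullback
end

# Rule for a wrapper that returns (value, flag), mirroring (factors, info).
# It reuses the inner pullback instead of re-deriving the gradient.
function square_with_flag_rule(x::Real)
    y, inner_pullback = square_rule(x)
    function wrapper_pullback(Δ)
        Δy, _ = Δ                      # the flag carries no gradient
        return inner_pullback(Δy)
    end
    return (y, 0), wrapper_pullback
end

out, back = square_with_flag_rule(3.0)
grad = back((1.0, nothing))
```

The wrapper's pullback only unpacks the tuple cotangent and forwards the differentiable part, so a correction to the inner rule carries over to the wrapper without further changes.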

return upper(C.factors)
else
return copy(lower(C.factors)')
function ChainRules.rrule(::typeof(turing_chol), A::AbstractMatrix, check)
Member

Currently chol_blocked_rev is only tested for real matrices. Probably the right thing to do is relax that constraint on ChainRules and test (I'll open a PR).

@yebai yebai mentioned this pull request Jul 20, 2020
@nmheim
Collaborator Author

nmheim commented Jul 20, 2020

I will push the suggestions by @sethaxen once ChainRules can deal with keyword arguments. Then this should be ready.

@nmheim
Collaborator Author

nmheim commented Jul 24, 2020

Ok, it will take some more time until the keyword stuff for the cholesky rule is fixed. Seth will ping me once that's done. I think we could merge this? @devmotion @mohamed82008

@nmheim nmheim merged commit af92680 into master Jul 26, 2020
@delete-merged-branch delete-merged-branch bot deleted the require-backends branch July 26, 2020 16:33
@yebai yebai mentioned this pull request Jul 30, 2020