Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require backends #95

Merged
merged 29 commits into from
Jul 26, 2020
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ca03858
move AD glue to separate required files
nmheim Jun 30, 2020
d1a7cfe
comment out all tests except for univariates
nmheim Jun 30, 2020
405dac2
readd Compat
nmheim Jun 30, 2020
0fde1f5
add combinatorics to test deps
nmheim Jun 30, 2020
dcb2959
remove adjoint for logabsgamma
nmheim Jul 1, 2020
25046c8
fix require warnings
nmheim Jul 1, 2020
aca8a1a
import -> using
nmheim Jul 1, 2020
64091fd
remove Zygote dependency from Tracker backend
nmheim Jul 1, 2020
cb46ece
Zygote -> ZygoteRules
nmheim Jul 1, 2020
93b147b
`...turing_chol` pullbacks
nmheim Jul 1, 2020
947bd0e
start separating tests for different backends
nmheim Jul 1, 2020
cdd29e8
some test fixes
nmheim Jul 1, 2020
ca03393
add mvcategorical and multvariate
nmheim Jul 2, 2020
6f40587
test multivariate
nmheim Jul 2, 2020
77d04bd
Simplify backend loading
nmheim Jul 2, 2020
c247251
move to_posdef to own module so that Tracker overloads can be used co…
nmheim Jul 2, 2020
ac625f5
Merge branch 'require-backends' of github.com:TuringLang/Distribution…
nmheim Jul 2, 2020
d75990f
fix tracker rrules
nmheim Jul 2, 2020
ca62a31
Merge branch 'master' into require-backends
nmheim Jul 2, 2020
47ee2e2
temporarily fix backend loading
nmheim Jul 2, 2020
47d2b55
requires works in runtests but not in ad/distributions?
nmheim Jul 2, 2020
ac35221
fully separate backend tests
nmheim Jul 2, 2020
5607f14
matrix variates
nmheim Jul 2, 2020
b1484f9
product dists
nmheim Jul 2, 2020
20b7e5c
separate Tracker/ForwardDiff actions
nmheim Jul 2, 2020
09d6cd4
move adjoint from zygote back
nmheim Jul 2, 2020
cdb5c3c
Revert "move adjoint from zygote back"
nmheim Jul 2, 2020
5ec5201
comments on poissonbinomial_pdf_fft
nmheim Jul 2, 2020
a4c7254
fix ChainRules overloads as suggested by @sethaxen
nmheim Jul 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
Manifest.toml
Manifest.toml
*.swp
12 changes: 4 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there anything in ChainRules specifically that you need, or could you use ChainRulesCore, which is much lighter? There's also ChainRulesTestUtils.rrule_test, which makes it easy to test the rules directly.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, thanks a lot for all the suggestions! :)

ChainRules is needed for now, because we are using chol_blocked_rev here:

∂X = @thunk(ChainRules.chol_blocked_rev(f̄, factors, 25, true))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're welcome! That makes sense. I added an alternate suggestion below.

Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,33 +18,30 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7, 1.0"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3"
FillArrays = "0.8"
ForwardDiff = "0.10.6"
MacroTools = "0.5"
NaNMath = "0.3"
PDMats = "0.9, 0.10"
Requires = "1"
SpecialFunctions = "0.8, 0.9, 0.10"
StaticArrays = "0.12"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
ZygoteRules = "0.2"
julia = "1"

[extras]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["FiniteDifferences", "Test", "ReverseDiff", "Zygote"]
test = ["Combinatorics", "FiniteDifferences", "Test", "ReverseDiff", "Zygote", "Tracker"]
46 changes: 32 additions & 14 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
module DistributionsAD

using PDMats,
ForwardDiff,
LinearAlgebra,
Distributions,
Random,
Combinatorics,
SpecialFunctions,
StatsFuns,
Compat,
Requires
Requires,
ZygoteRules,
ChainRules

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
Expand All @@ -35,28 +32,49 @@ import Distributions: MvNormal,
Binomial,
BetaBinomial,
Erlang
import ZygoteRules

export TuringScalMvNormal,
TuringDiagMvNormal,
TuringDenseMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart,
arraydist,
filldist
TuringPoissonBinomial
#TuringWishart,
#TuringInverseWishart,
#arraydist,
#filldist

include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("mvcategorical.jl")
#=
include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")
@init @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("reversediff.jl")
=#

include("zygote.jl")

function __init__()
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
include("forwarddiff.jl")
include("zygote_forwarddiff.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("reversediff.jl")
end

@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
using DiffRules
using SpecialFunctions
using LinearAlgebra: AbstractTriangular
using .Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix,
TrackedArray, TrackedVecOrMat, track, @grad, data
include("tracker.jl")
end
end

end
Loading