Skip to content

Commit

Permalink
Merge pull request #95 from TuringLang/require-backends
Browse files Browse the repository at this point in the history
Require backends

Separates the code of the four AD backends into four files such that they can be loaded optionally.
  • Loading branch information
nmheim authored Jul 26, 2020
2 parents 809bdc7 + a4c7254 commit af92680
Show file tree
Hide file tree
Showing 21 changed files with 1,024 additions and 835 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: ForwardDiff and Tracker tests
name: ForwardDiff tests

on:
push:
Expand Down Expand Up @@ -29,4 +29,4 @@ jobs:
- uses: julia-actions/julia-runtest@latest
env:
GROUP: AD
AD: ForwardDiff_Tracker
AD: ForwardDiff
32 changes: 32 additions & 0 deletions .github/workflows/Tracker.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: Tracker tests

on:
push:
branches:
- master
pull_request:

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
version:
- '1.0'
- '1'
os:
- ubuntu-latest
- macOS-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
env:
GROUP: AD
AD: Tracker
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
Manifest.toml
Manifest.toml
*.swp
12 changes: 4 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -19,33 +18,30 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7, 1.0"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3"
FillArrays = "0.8"
ForwardDiff = "0.10.6"
MacroTools = "0.5"
NaNMath = "0.3"
PDMats = "0.9, 0.10"
Requires = "1"
SpecialFunctions = "0.8, 0.9, 0.10"
StaticArrays = "0.12"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
ZygoteRules = "0.2"
julia = "1"

[extras]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["FiniteDifferences", "Test", "ReverseDiff", "Zygote"]
test = ["Combinatorics", "FiniteDifferences", "Test", "ReverseDiff", "Zygote", "Tracker"]
38 changes: 28 additions & 10 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
module DistributionsAD

using PDMats,
ForwardDiff,
LinearAlgebra,
Distributions,
Random,
Combinatorics,
SpecialFunctions,
StatsFuns,
Compat,
Requires
Requires,
ZygoteRules,
ChainRules, # needed for `ChainRules.chol_blocked_rev`
FillArrays

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
Expand All @@ -35,7 +32,6 @@ import Distributions: MvNormal,
Binomial,
BetaBinomial,
Erlang
import ZygoteRules

export TuringScalMvNormal,
TuringDiagMvNormal,
Expand All @@ -55,8 +51,30 @@ include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")
@init @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("reversediff.jl")

include("zygote.jl")

@init begin
@require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
include("forwarddiff.jl")

# loads adjoint for `poissonbinomial_pdf_fft`
include("zygote_forwarddiff.jl")
end

@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("reversediff.jl")
end

@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
using DiffRules
using SpecialFunctions
using LinearAlgebra: AbstractTriangular
using .Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix,
TrackedArray, TrackedVecOrMat, track, @grad, data
include("tracker.jl")
end
end

end
Loading

0 comments on commit af92680

Please sign in to comment.