Backport of ADTypes.jl compat and others #55

Closed · wants to merge 2 commits
22 changes: 16 additions & 6 deletions Project.toml
@@ -3,7 +3,9 @@ uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.1.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -16,22 +18,30 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[weakdeps]
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
AdvancedVIReverseDiffExt = ["ReverseDiff"]
AdvancedVIZygoteExt = ["Zygote"]

[compat]
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10"
ADTypes = "0.2, 1"
Bijectors = "0.4.0, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.13"
Distributions = "0.21, 0.22, 0.23, 0.24, 0.25"
DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6"
DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10.3"
ProgressMeter = "1.0.0"
Requires = "0.5, 1.0"
ReverseDiff = "1"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
Tracker = "0.2.3"
Zygote = "0.6"
julia = "1"

[extras]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "Test"]
13 changes: 13 additions & 0 deletions ext/AdvancedVIFluxExt.jl
@@ -0,0 +1,13 @@
module AdvancedVIFluxExt

if isdefined(Base, :get_extension)
    using AdvancedVI: AdvancedVI
    using Flux: Flux
else
    using ..AdvancedVI: AdvancedVI
    using ..Flux: Flux
end

AdvancedVI.apply!(o::Flux.Optimise.AbstractOptimizer, x, Δ) = Flux.Optimise.apply!(o, x, Δ)

end
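This extension only forwards `AdvancedVI.apply!` to `Flux.Optimise.apply!`. An illustrative call, assuming Flux's `Optimise` submodule API (the optimiser choice and values are arbitrary, not from the diff):

```julia
using AdvancedVI, Flux

opt = Flux.Optimise.Descent(0.01)  # any optimiser from Flux.Optimise
x = randn(3)                       # variational parameters
Δ = randn(3)                       # raw gradient

Δ = AdvancedVI.apply!(opt, x, Δ)   # optimiser rescales the step
x .-= Δ
```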
40 changes: 40 additions & 0 deletions ext/AdvancedVIReverseDiffExt.jl
@@ -0,0 +1,40 @@
module AdvancedVIReverseDiffExt

if isdefined(Base, :get_extension)
    using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
    using ReverseDiff: ReverseDiff
else
    using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
    using ..ReverseDiff: ReverseDiff
end

AdvancedVI.ADBackend(::Val{:reversediff}) = ADTypes.AutoReverseDiff()

function AdvancedVI.setadbackend(::Val{:reversediff})
    Base.depwarn("`setadbackend` is deprecated. Please pass an `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
    AdvancedVI.ADBACKEND[] = :reversediff
end

tape(f, x) = ReverseDiff.GradientTape(f, x)

function AdvancedVI.grad!(
    vo,
    alg::AdvancedVI.VariationalInference{<:ADTypes.AutoReverseDiff},
    q,
    model,
    θ::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
    args...
)
    f(θ) =
        if (q isa Distributions.Distribution)
            -vo(alg, AdvancedVI.update(q, θ), model, args...)
        else
            -vo(alg, q(θ), model, args...)
        end
    tp = tape(f, θ)
    ReverseDiff.gradient!(out, tp, θ)
    return out
end

end
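The `grad!` method above records a fresh `ReverseDiff.GradientTape` on each call and writes value and gradient into one `DiffResults` container. The underlying pattern in isolation, with a stand-in objective (self-contained sketch, not AdvancedVI code):

```julia
using ReverseDiff, DiffResults

f(θ) = sum(abs2, θ)                  # stand-in for the negated variational objective
θ = randn(5)
out = DiffResults.GradientResult(θ)  # holds value and gradient together

tp = ReverseDiff.GradientTape(f, θ)  # record the computation once ...
ReverseDiff.gradient!(out, tp, θ)    # ... then replay it to fill `out`

DiffResults.value(out)               # f(θ)
DiffResults.gradient(out)            # ∇f(θ) == 2θ
```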
39 changes: 39 additions & 0 deletions ext/AdvancedVIZygoteExt.jl
@@ -0,0 +1,39 @@
module AdvancedVIZygoteExt

if isdefined(Base, :get_extension)
    using AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
    using Zygote: Zygote
else
    using ..AdvancedVI: AdvancedVI, ADTypes, DiffResults, Distributions
    using ..Zygote: Zygote
end

AdvancedVI.ADBackend(::Val{:zygote}) = ADTypes.AutoZygote()
function AdvancedVI.setadbackend(::Val{:zygote})
    Base.depwarn("`setadbackend` is deprecated. Please pass an `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
    AdvancedVI.ADBACKEND[] = :zygote
end

function AdvancedVI.grad!(
    vo,
    alg::AdvancedVI.VariationalInference{<:ADTypes.AutoZygote},
    q,
    model,
    θ::AbstractVector{<:Real},
    out::DiffResults.MutableDiffResult,
    args...
)
    f(θ) =
        if (q isa Distributions.Distribution)
            -vo(alg, AdvancedVI.update(q, θ), model, args...)
        else
            -vo(alg, q(θ), model, args...)
        end
    y, back = Zygote.pullback(f, θ)
    dy = first(back(1.0))
    DiffResults.value!(out, y)
    DiffResults.gradient!(out, dy)
    return out
end

end
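Zygote has no `DiffResults` integration of its own, so the method above assembles the result manually from `Zygote.pullback`. The same pattern in isolation (self-contained sketch with a stand-in objective):

```julia
using Zygote, DiffResults

f(θ) = sum(abs2, θ)
θ = randn(5)
out = DiffResults.GradientResult(θ)

y, back = Zygote.pullback(f, θ)  # forward pass plus pullback closure
dy = first(back(1.0))            # seed with 1.0 since f is scalar-valued

DiffResults.value!(out, y)
DiffResults.gradient!(out, dy)
```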
93 changes: 25 additions & 68 deletions src/AdvancedVI.jl
@@ -1,14 +1,17 @@
module AdvancedVI

using Random: AbstractRNG
using Random: Random, AbstractRNG

using Distributions, DistributionsAD, Bijectors
using DocStringExtensions

using ProgressMeter, LinearAlgebra

using ForwardDiff
using Tracker
using ADTypes: ADTypes
using DiffResults: DiffResults

using ForwardDiff: ForwardDiff
using Tracker: Tracker

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@@ -18,65 +21,6 @@ end

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0")))

include("ad.jl")

using Requires
function __init__()
    @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
        apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ)
        Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ)
        Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ)
    end
    @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
        include("compat/zygote.jl")
        export ZygoteAD

        function AdvancedVI.grad!(
            vo,
            alg::VariationalInference{<:AdvancedVI.ZygoteAD},
            q,
            model,
            θ::AbstractVector{<:Real},
            out::DiffResults.MutableDiffResult,
            args...
        )
            f(θ) = if (q isa Distribution)
                - vo(alg, update(q, θ), model, args...)
            else
                - vo(alg, q(θ), model, args...)
            end
            y, back = Zygote.pullback(f, θ)
            dy = first(back(1.0))
            DiffResults.value!(out, y)
            DiffResults.gradient!(out, dy)
            return out
        end
    end
    @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
        include("compat/reversediff.jl")
        export ReverseDiffAD

        function AdvancedVI.grad!(
            vo,
            alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
            q,
            model,
            θ::AbstractVector{<:Real},
            out::DiffResults.MutableDiffResult,
            args...
        )
            f(θ) = if (q isa Distribution)
                - vo(alg, update(q, θ), model, args...)
            else
                - vo(alg, q(θ), model, args...)
            end
            tp = AdvancedVI.tape(f, θ)
            ReverseDiff.gradient!(out, tp, θ)
            return out
        end
    end
end

export
    vi,
    ADVI,
@@ -86,10 +30,12 @@ export
    DecayedADAGrad,
    VariationalInference

include("compat.jl")
include("ad.jl")

abstract type VariationalInference{AD} end

getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
getADtype(::VariationalInference{AD}) where AD = AD
getchunksize(::ADTypes.AutoForwardDiff{chunk}) where chunk = chunk === nothing ? 0 : chunk

abstract type VariationalObjective end

@@ -100,7 +46,7 @@ const VariationalPosterior = Distribution{Multivariate, Continuous}
grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)

Computes the gradients used in `optimize!`. Default implementation is provided for
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
`VariationalInference{AD}` where `AD` is either `ADTypes.AutoForwardDiff` or `ADTypes.AutoTracker`.
This implicitly also gives a default implementation of `optimize!`.

Variance reduction techniques, e.g. control variates, should be implemented in this function.
@@ -129,7 +75,7 @@ function update end
# default implementations
function grad!(
    vo,
    alg::VariationalInference{<:ForwardDiffAD},
    alg::VariationalInference{<:ADTypes.AutoForwardDiff},
    q,
    model,
    θ::AbstractVector{<:Real},
@@ -143,7 +89,7 @@ function grad!(
    end

    # Set chunk size and do ForwardMode.
    chunk_size = getchunksize(typeof(alg))
    chunk_size = getchunksize(alg.adtype)
    config = if chunk_size == 0
        ForwardDiff.GradientConfig(f, θ)
    else
@@ -154,7 +100,7 @@ end

function grad!(
    vo,
    alg::VariationalInference{<:TrackerAD},
    alg::VariationalInference{<:ADTypes.AutoTracker},
    q,
    model,
    θ::AbstractVector{<:Real},
@@ -238,4 +184,15 @@ include("optimisers.jl")
# VI algorithms
include("advi.jl")

@static if !isdefined(Base, :get_extension)
    function __init__()
        @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
            "../ext/AdvancedVIReverseDiffExt.jl"
        )
        @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
            "../ext/AdvancedVIZygoteExt.jl"
        )
    end
end

end # module
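The `isdefined(Base, :get_extension)` guard in `__init__` mirrors the one at the top of each extension file, so a single file serves both loading mechanisms. The pattern in generic form (`MyPkg` and `SomeDep` are placeholder names, not from this PR):

```julia
module MyPkgSomeDepExt

if isdefined(Base, :get_extension)
    # Julia >= 1.9: loaded as a package extension, so the weak
    # dependency resolves as a top-level package.
    using MyPkg: MyPkg
    using SomeDep: SomeDep
else
    # Older Julia: the file is `include`d into MyPkg by Requires.jl,
    # so names must be resolved relative to the parent module.
    using ..MyPkg: MyPkg
    using ..SomeDep: SomeDep
end

end
```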
16 changes: 6 additions & 10 deletions src/ad.jl
@@ -1,13 +1,12 @@
##############################
# Global variables/constants #
##############################
# FIXME: All this should go away.
const ADBACKEND = Ref(:forwarddiff)
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
function setadbackend(::Val{:forward_diff})
    Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
    setadbackend(Val(:forwarddiff))
end
function setadbackend(::Val{:forwarddiff})
    Base.depwarn("`setadbackend` is deprecated. Please pass an `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
    ADBACKEND[] = :forwarddiff
end

@@ -16,6 +15,7 @@ function setadbackend(::Val{:reverse_diff})
    setadbackend(Val(:tracker))
end
function setadbackend(::Val{:tracker})
    Base.depwarn("`setadbackend` is deprecated. Please pass an `ADTypes.AbstractADType` as a keyword argument to the VI algorithm.", :setadbackend)
    ADBACKEND[] = :tracker
end

@@ -32,15 +32,11 @@ function setchunksize(chunk_size::Int)
    CHUNKSIZE[] = chunk_size
end

abstract type ADBackend end
struct ForwardDiffAD{chunk} <: ADBackend end
getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk

struct TrackerAD <: ADBackend end
getchunksize(::Type{<:ADTypes.AutoForwardDiff{chunk}}) where chunk = chunk

ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))

ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
ADBackend(::Val{:tracker}) = TrackerAD
ADBackend(::Val{:forwarddiff}) = ADTypes.AutoForwardDiff(chunksize=CHUNKSIZE[])
ADBackend(::Val{:tracker}) = ADTypes.AutoTracker()
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
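With these changes the legacy symbol-based selection resolves to `ADTypes` structs instead of the removed `ADBackend` subtypes. Illustrative behavior under this diff (comments show expected results; backends defined in extensions stay unavailable until their package is loaded):

```julia
using AdvancedVI, ADTypes

AdvancedVI.ADBackend(Val(:forwarddiff))  # AutoForwardDiff(chunksize = CHUNKSIZE[])
AdvancedVI.ADBackend(:tracker)           # AutoTracker(), via the Symbol method

# Errors with "The requested AD backend is not available." unless
# ReverseDiff is loaded, since that method lives in the extension:
AdvancedVI.ADBackend(Val(:reversediff))
```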
7 changes: 5 additions & 2 deletions src/advi.jl
@@ -20,10 +20,13 @@ struct ADVI{AD} <: VariationalInference{AD}
    samples_per_step::Int
    "Maximum number of gradient steps."
    max_iters::Int
    "AD backend used for automatic differentiation."
    adtype::AD
end

function ADVI(samples_per_step::Int=1, max_iters::Int=1000)
    return ADVI{ADBackend()}(samples_per_step, max_iters)
function ADVI(samples_per_step::Int=1, max_iters::Int=1000; adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff())
    return ADVI(samples_per_step, max_iters, adtype)
end

alg_str(::ADVI) = "ADVI"
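The keyword constructor is the main user-facing change of this backport: the AD backend moves from the `ADVI` type parameter into an `adtype` field. An illustrative construction (sample counts arbitrary):

```julia
using AdvancedVI, ADTypes

alg = ADVI(10, 1_000)  # default backend: ADTypes.AutoForwardDiff()

using ReverseDiff      # load the backend so the extension's grad! method exists
alg_rd = ADVI(10, 1_000; adtype = ADTypes.AutoReverseDiff())
```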