14 changes: 7 additions & 7 deletions test/Project.toml
@@ -2,6 +2,7 @@
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -12,12 +13,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -28,13 +27,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0.1"
AdvancedHMC = "0.2.25"
AdvancedMH = "0.5.1"
AdvancedVI = "0.1"
Bijectors = "0.8.2"
Distributions = "0.23.8"
DistributionsAD = "0.6.3"
@@ -45,13 +45,13 @@ Libtask = "0.4.1"
LogDensityProblems = "0.10.3"
MCMCChains = "4.0.4"
MacroTools = "0.5.5"
PDMats = "0.10"
ProgressLogging = "0.1.3"
NamedArrays = "0.9"
Reexport = "0.2"
Requires = "1.0.1"
SpecialFunctions = "0.10.3"
StatsBase = "0.33"
StatsFuns = "0.9.5"
Tracker = "0.2.11"
Zygote = "0.5.4"
ZygoteRules = "0.2"
julia = "1.3"
13 changes: 12 additions & 1 deletion test/Turing/Turing.jl
@@ -15,12 +15,14 @@ using Libtask
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
using Tracker: Tracker

import AdvancedVI
import DynamicPPL: getspace, NoDist, NamedDist

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
PROGRESS[] = switch
AdvancedVI.turnprogress(switch)
end

# Random probability measures.
@@ -64,6 +66,9 @@ end
###########
# Exports #
###########
# `using` statements for stuff to re-export
using DynamicPPL: elementwise_loglikelihoods, generated_quantities, logprior, logjoint
using StatsBase: predict

# Turing essentials - modelling macros and inference algorithms
export @model, # modelling
@@ -114,5 +119,11 @@ export @model, # modelling
LogPoisson,
NamedDist,
filldist,
arraydist
arraydist,

predict,
elementwise_loglikelihoods,
generated_quantities,
logprior,
logjoint
end
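The re-exports above surface `predict`, `elementwise_loglikelihoods`, `generated_quantities`, `logprior`, and `logjoint` directly from Turing. A minimal sketch of the two model utilities, assuming a toy model (`demo`, the data, and the sampler settings are illustrative, not part of this diff):

using Turing

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
    return (s = s, m = m)  # returned values are what generated_quantities collects
end

model = demo(randn(10))
chain = sample(model, NUTS(0.65), 200)

gq = generated_quantities(model, chain)  # model return value per posterior draw

# Predict new observations by instantiating the model with missing data.
predictions = predict(demo(Vector{Union{Missing, Float64}}(undef, 10)), chain)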
4 changes: 2 additions & 2 deletions test/Turing/contrib/inference/dynamichmc.jl
@@ -127,7 +127,7 @@ end
kwargs...
)
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
if resume_from === nothing
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
@@ -149,7 +149,7 @@ function AbstractMCMC.sample(
kwargs...
)
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
chain_type=chain_type, progress=false, kwargs...)
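Both warnings fire when a caller requests progress logging for the DynamicHMC-backed sampler, which Turing then switches off in favour of DynamicHMC's own meter. A hedged usage sketch, reusing the `demo` model from the sketch above (`DynamicNUTS` is defined in this contrib file and requires DynamicHMC to be loaded):

using Turing, DynamicHMC

# `progress=true` triggers the warning above and is then forced to `false`.
chain = sample(demo(randn(10)), DynamicNUTS(), 1_000; progress=true)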
10 changes: 6 additions & 4 deletions test/Turing/core/Core.jl
@@ -10,17 +10,18 @@ using DynamicPPL: Model, AbstractSampler, Sampler, SampleFromPrior
using LinearAlgebra: copytri!
using Bijectors: PDMatDistribution
import Bijectors: link, invlink
using AdvancedVI
using StatsFuns: logsumexp, softmax
@reexport using DynamicPPL
using Requires

import ZygoteRules

include("container.jl")
include("ad.jl")
include("deprecations.jl")

function __init__()
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("compat/zygote.jl")
export ZygoteAD
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
include("compat/reversediff.jl")
export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache
@@ -50,6 +51,7 @@ export @model,
setadsafe,
ForwardDiffAD,
TrackerAD,
ZygoteAD,
value,
gradient_logp,
CHUNKSIZE,
62 changes: 47 additions & 15 deletions test/Turing/core/ad.jl
@@ -3,22 +3,22 @@
##############################
const ADBACKEND = Ref(:forwarddiff)
setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym))
function setadbackend(::Val{:forward_diff})
Base.depwarn("`Turing.setadbackend(:forward_diff)` is deprecated. Please use `Turing.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend)
setadbackend(Val(:forwarddiff))
function setadbackend(backend::Val)
_setadbackend(backend)
AdvancedVI.setadbackend(backend)
Bijectors.setadbackend(backend)
end
function setadbackend(::Val{:forwarddiff})

function _setadbackend(::Val{:forwarddiff})
CHUNKSIZE[] == 0 && setchunksize(40)
ADBACKEND[] = :forwarddiff
end

function setadbackend(::Val{:reverse_diff})
Base.depwarn("`Turing.setadbackend(:reverse_diff)` is deprecated. Please use `Turing.setadbackend(:tracker)` to use `Tracker` or `Turing.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend)
setadbackend(Val(:tracker))
end
function setadbackend(::Val{:tracker})
function _setadbackend(::Val{:tracker})
ADBACKEND[] = :tracker
end
function _setadbackend(::Val{:zygote})
ADBACKEND[] = :zygote
end

const ADSAFE = Ref(false)
function setadsafe(switch::Bool)
@@ -42,12 +42,14 @@ getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg)
getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[]

struct TrackerAD <: ADBackend end
struct ZygoteAD <: ADBackend end

ADBackend() = ADBackend(ADBACKEND[])
ADBackend(T::Symbol) = ADBackend(Val(T))

ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]}
ADBackend(::Val{:tracker}) = TrackerAD
ADBackend(::Val{:zygote}) = ZygoteAD
ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")

"""
@@ -56,13 +58,15 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.")
Find the autodifferentiation backend of the algorithm `alg`.
"""
getADbackend(spl::Sampler) = getADbackend(spl.alg)
getADbackend(spl::SampleFromPrior) = ADBackend()()

"""
gradient_logp(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
sampler::AbstractSampler,
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)

Computes the value of the log joint of `θ` and its gradient for the model
@@ -73,9 +77,10 @@ function gradient_logp(
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::Sampler
sampler::AbstractSampler,
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
return gradient_logp(getADbackend(sampler), θ, vi, model, sampler)
return gradient_logp(getADbackend(sampler), θ, vi, model, sampler, ctx)
end

"""
@@ -85,6 +90,7 @@ gradient_logp(
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)

Compute the value of the log joint of `θ` and its gradient for the model
@@ -96,12 +102,13 @@ function gradient_logp(
vi::VarInfo,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
# Define function to compute log joint.
logp_old = getlogp(vi)
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, ctx)
logp = getlogp(new_vi)
setlogp!(vi, ForwardDiff.value(logp))
return logp
@@ -123,13 +130,14 @@ function gradient_logp(
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, ctx)
return getlogp(new_vi)
end

@@ -141,6 +149,30 @@ function gradient_logp(
return l, ∂l∂θ
end

function gradient_logp(
backend::ZygoteAD,
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler, context)
return getlogp(new_vi)
end

# Compute forward and reverse passes.
l::T, ȳ = ZygoteRules.pullback(f, θ)
∂l∂θ::typeof(θ) = ȳ(1)[1]

return l, ∂l∂θ
end

function verifygrad(grad::AbstractVector{<:Real})
if any(isnan, grad) || any(isinf, grad)
@warn("Numerical error in gradients. Rejecting current proposal...")
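With `setadbackend` now dispatching through `_setadbackend`, a single call selects the backend for Turing, AdvancedVI, and Bijectors at once, and `:zygote` joins the recognised backends. A rough sketch of the user-facing switch (assumes Zygote is installed; Zygote supplies the methods behind `ZygoteRules.pullback`):

using Turing
using Zygote  # needed so ZygoteRules.pullback has methods at sampling time

Turing.setadbackend(:zygote)       # also forwards to AdvancedVI and Bijectors
Turing.setadbackend(:forwarddiff)  # back to the default; sets CHUNKSIZE to 40 if unset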
28 changes: 13 additions & 15 deletions test/Turing/core/compat/reversediff.jl
@@ -10,7 +10,7 @@ function emptyrdcache end

getrdcache() = RDCache[]
ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
function setadbackend(::Val{:reversediff})
function _setadbackend(::Val{:reversediff})
ADBACKEND[] = :reversediff
end

@@ -20,13 +20,14 @@ function gradient_logp(
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
tp, result = taperesult(f, θ)
@@ -45,25 +46,24 @@ end
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
setrdcache(::Val{true}) = RDCache[] = true
function emptyrdcache()
for k in keys(Memoization.caches)
if k[1] === typeof(memoized_taperesult)
pop!(Memoization.caches, k)
end
end
Memoization.empty_cache!(memoized_taperesult)
return
end

function gradient_logp(
backend::ReverseDiffAD{true},
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
ctp, result = memoized_taperesult(f, θ)
Expand All @@ -79,15 +79,13 @@ end
f::F
x::Tx
end
function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any})
key = keys[1][1]
return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid()))
end
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
end
memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
compiledtape(f, x) = compile(GradientTape(f, x))
end
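The revised memoization keys cached tapes on the function object, the input's type and size, and the calling thread, and `emptyrdcache` now simply empties that cache. A hedged sketch of the cached-tape workflow, using the names exported from Core.jl (ReverseDiff and Memoization must both be loaded for this path to exist):

using Turing
using ReverseDiff, Memoization

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)   # compile gradient tapes once and memoize them
# ... sample as usual ...
Turing.emptyrdcache()     # drop all cached tapes, e.g. after redefining a model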
28 changes: 0 additions & 28 deletions test/Turing/core/compat/zygote.jl

This file was deleted.
