6 changes: 6 additions & 0 deletions HISTORY.md
@@ -1,3 +1,9 @@
# 0.39.5

Fixed a bug where sampling with an `externalsampler` would not set the log probability density inside the resulting chain.
Note that there may still be bugs where the log-Jacobian term is not correctly included.
A fix is being worked on.
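
As an illustration (this snippet is not part of the original changelog; the model and the AdvancedMH random-walk setup are assumptions made purely for the example), the fixed behaviour can be checked roughly as follows:

```julia
using Turing, Distributions
using LinearAlgebra: I
import AdvancedMH

@model logp_check() = x ~ Normal()

# Wrap an AdvancedMH random-walk sampler as an external sampler.
spl = externalsampler(AdvancedMH.RWMH(MvNormal(zeros(1), I)))
chn = sample(logp_check(), spl, 1000)

# With this release, the chain's :lp column should match the model's log density.
isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
```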

# 0.39.4

Bumped compatibility of AbstractPPL to include 0.12.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.39.4"
version = "0.39.5"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
85 changes: 64 additions & 21 deletions src/mcmc/external_sampler.jl
@@ -7,6 +7,23 @@ The `Unconstrained` type-parameter is to indicate whether the sampler requires unconstrained space.

# Fields
$(TYPEDFIELDS)

# Turing.jl's interface for external samplers

When implementing a new `MySampler <: AbstractSampler`,
`MySampler` must first and foremost conform to the `AbstractMCMC` interface to work with Turing.jl's `externalsampler` function.
In particular, it must implement:

- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is documented in AbstractMCMC.jl)
- `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`: How to extract the parameters from the state returned by your sampler (i.e., the second return value of `step`).

There are a few more optional functions which you can implement to improve the integration with Turing.jl:

- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as a component in Turing's Gibbs sampler, you should make this evaluate to `true`.

- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires unconstrained space, you should return `true`. This tells Turing to perform linking on the VarInfo before evaluation, and ensures that the parameter values passed to your sampler will always be in unconstrained (Euclidean) space.

- `Turing.Inference.getlogp_external(external_transition, external_state)`: Tells Turing how to extract the log probability density associated with this transition (and state). If you do not implement this method, Turing will simply re-evaluate the model with the parameters obtained from `getparams`, which can be inefficient. It is therefore recommended to store the log probability density in either the transition or the state (or both) and override this method.
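
As a rough, illustrative sketch (the names `MySampler`, `MyTransition`, and `MyState` are placeholders, and the proposal below is deliberately trivial; a real sampler would carry its own state and proposal logic), an implementation might look like:

```julia
using AbstractMCMC, DynamicPPL, LogDensityProblems, Random
import Turing

struct MySampler <: AbstractMCMC.AbstractSampler end

struct MyTransition
    params::Vector{Float64}
    lp::Float64
end

struct MyState
    transition::MyTransition
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state=nothing;
    kwargs...,
)
    # Placeholder proposal: draw independently from a standard normal.
    d = LogDensityProblems.dimension(model.logdensity)
    params = randn(rng, d)
    lp = LogDensityProblems.logdensity(model.logdensity, params)
    t = MyTransition(params, lp)
    return t, MyState(t)
end

# Required: extract the parameter vector from the sampler's state.
AbstractMCMC.getparams(::DynamicPPL.Model, state::MyState) = state.transition.params

# Optional: reuse the stored log density instead of re-evaluating the model.
Turing.Inference.getlogp_external(t::MyTransition, ::MyState) = t.lp
# Optional: allow use as a component of Turing's Gibbs sampler.
Turing.Inference.isgibbscomponent(::MySampler) = true
# Optional: ask Turing to pass parameters in unconstrained (Euclidean) space.
Turing.Inference.requires_unconstrained_space(::MySampler) = true
```

Such a sampler could then be used as `sample(model, externalsampler(MySampler()), 1000)`.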
"""
struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
InferenceAlgorithm
@@ -68,30 +85,28 @@ function externalsampler(
return ExternalSampler(sampler, adtype, Val(unconstrained))
end

struct TuringState{S,M,V,C}
"""
getlogp_external(external_transition, external_state)

Get the log probability density associated with the external sampler's
transition and state. Returns `missing` by default; in this case, an extra
model evaluation will be needed to calculate the correct log density.
"""
getlogp_external(::Any, ::Any) = missing
getlogp_external(mh::AdvancedMH.Transition, ::AdvancedMH.Transition) = mh.lp
getlogp_external(hmc::AdvancedHMC.Transition, ::AdvancedHMC.HMCState) = hmc.stat.log_density

struct TuringState{S,V1<:AbstractVarInfo,M,V,C}
state::S
# Note that this varinfo has the correct parameters and logp obtained from
# the state, whereas `ldf.varinfo` will in general have junk inside it.
varinfo::V1
ldf::DynamicPPL.LogDensityFunction{M,V,C}
end

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
# TODO: We should probably rename this `getparams` since it returns something
# very different from `Turing.Inference.getparams`.
θ = getparams(f.model, transition)
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
return Transition(f.model, varinfo, transition)
end

function varinfo(state::TuringState)
θ = getparams(state.ldf.model, state.state)
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(state.ldf.varinfo, θ)
end
varinfo(state::TuringState) = state.varinfo
varinfo(state::AbstractVarInfo) = state

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
return getparams(model, state.transition)
@@ -100,6 +115,21 @@ getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

function make_updated_varinfo(
f::DynamicPPL.LogDensityFunction, external_transition, external_state
)
# Set the parameters.
new_parameters = getparams(f.model, external_state)
new_varinfo = DynamicPPL.unflatten(f.varinfo, new_parameters)
# Set (or recalculate, if needed) the log density.
new_logp = getlogp_external(external_transition, external_state)
return if ismissing(new_logp)
last(DynamicPPL.evaluate!!(f.model, new_varinfo, f.context))
else
DynamicPPL.setlogp!!(new_varinfo, new_logp)
end
end

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
@@ -143,8 +173,15 @@ function AbstractMCMC.step(
kwargs...,
)
end

# Get the parameters and log density, and set them in the varinfo.
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)

# Update the `state`
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
return (
Transition(f.model, new_varinfo, transition_inner),
TuringState(state_inner, new_varinfo, f),
)
end

function AbstractMCMC.step(
@@ -157,11 +194,17 @@
sampler = sampler_wrapper.alg.sampler
f = state.ldf

# Then just call `AdvancedHMC.step` with the right arguments.
    # Then just call `AbstractMCMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
)

# Get the parameters and log density, and set them in the varinfo.
new_varinfo = make_updated_varinfo(f, transition_inner, state_inner)

# Update the `state`
return transition_to_turing(f, transition_inner), state_to_turing(f, state_inner)
return (
Transition(f.model, new_varinfo, transition_inner),
TuringState(state_inner, new_varinfo, f),
)
end
5 changes: 3 additions & 2 deletions src/mcmc/gibbs.jl
@@ -18,8 +18,9 @@ isgibbscomponent(::PG) = true
isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)

isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
isgibbscomponent(::AdvancedHMC.HMC) = true
isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
isgibbscomponent(spl) = false

function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
@@ -561,7 +562,7 @@ function setparams_varinfo!!(
new_inner_state = setparams_varinfo!!(
AbstractMCMC.LogDensityModel(logdensity), sampler, state.state, params
)
return TuringState(new_inner_state, logdensity)
return TuringState(new_inner_state, params, logdensity)
end

function setparams_varinfo!!(
21 changes: 21 additions & 0 deletions test/mcmc/external_sampler.jl
@@ -141,6 +141,17 @@ end
)
end
end

@testset "logp is set correctly" begin
@model logp_check() = x ~ Normal()
model = logp_check()
sampler = initialize_nuts(model)
sampler_ext = externalsampler(
sampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained=true
)
chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100)
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
end
end

@testset "AdvancedMH.jl" begin
@@ -167,7 +178,17 @@
)
end
end

@testset "logp is set correctly" begin
@model logp_check() = x ~ Normal()
model = logp_check()
sampler = initialize_mh_rw(model)
sampler_ext = externalsampler(sampler; unconstrained=true)
chn = sample(logp_check(), Gibbs(@varname(x) => sampler_ext), 100)
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
end
end

# NOTE: Broken because MH doesn't really follow the `logdensity` interface, but calls
# it with `NamedTuple` instead of `AbstractVector`.
# @testset "MH with prior proposal" begin
9 changes: 9 additions & 0 deletions test/mcmc/gibbs.jl
@@ -825,6 +825,14 @@ end
end

@testset "externalsampler" begin
function check_logp_correct(sampler)
@testset "logp is set correctly" begin
@model logp_check() = x ~ Normal()
chn = sample(logp_check(), Gibbs(@varname(x) => sampler), 100)
@test isapprox(logpdf.(Normal(), chn[:x]), chn[:lp])
end
end

@model function demo_gibbs_external()
m1 ~ Normal()
m2 ~ Normal()
@@ -851,6 +859,7 @@
model, sampler, 1000; discard_initial=1000, thinning=10, n_adapts=0
)
check_numerical(chain, [:m1, :m2], [-0.2, 0.6]; atol=0.1)
check_logp_correct(sampler_inner)
end
end
