Skip to content
Merged
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.31.3"
version = "0.31.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
14 changes: 9 additions & 5 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,15 @@ Return a named tuple of parameters.
"""
getparams(model, t) = t.θ
function getparams(model::DynamicPPL.Model, vi::DynamicPPL.VarInfo)
# Want the end-user to receive parameters in constrained space, so we `link`.
vi = DynamicPPL.invlink(vi, model)

# Extract parameter values in a simple form from the `VarInfo`.
vals = DynamicPPL.values_as(vi, OrderedDict)
# NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
# Unfortunately, using `invlink` can cause issues in scenarios where the constraints
# of the parameters change depending on the realizations. Hence we have to use
# `values_as_in_model`, which re-runs the model and extracts the parameters
# as they are seen in the model, i.e. in the constrained space. Moreover,
# this means that the code below will work both of linked and invlinked `vi`.
# Ref: https://github.com/TuringLang/Turing.jl/issues/2195
# NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
vals = DynamicPPL.values_as_in_model(model, deepcopy(vi))

# Obtain an iterator over the flattened parameter names and values.
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down Expand Up @@ -43,6 +44,7 @@ DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.25.1"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10"
HypothesisTests = "0.11"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.4"
MCMCChains = "5, 6"
Expand Down
42 changes: 42 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,46 @@
@test mean(Array(chain)) ≈ 0.2
end
end

@turing_testset "issue: #2195" begin
@model function buggy_model()
lb ~ Uniform(0, 1)
ub ~ Uniform(1.5, 2)

# HACK: Necessary to avoid NUTS failing during adaptation.
try
x ~ transformed(Normal(0, 1), inverse(Bijectors.Logit(lb, ub)))
catch e
if e isa DomainError
Turing.@addlogprob! -Inf
return nothing
else
rethrow()
end
end
end

model = buggy_model();
num_samples = 1_000;

chain = sample(
model,
NUTS(),
num_samples;
initial_params=[0.5, 1.75, 1.0]
)
chain_prior = sample(model, Prior(), num_samples)

# Extract the `x` like this because running `generated_quantities` was how
# the issue was discovered, hence we also want to make sure that it works.
results = generated_quantities(model, chain)
results_prior = generated_quantities(model, chain_prior)

# Make sure none of the samples in the chains resulted in errors.
@test all(!isnothing, results)

# The discrepancies in the chains are in the tails, so we can't just compare the mean, etc.
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.01
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using ReverseDiff
using SpecialFunctions
using StatsBase
using StatsFuns
using HypothesisTests
using Tracker
using Turing
using Turing.Inference
Expand Down