Commit a303b9b

Merge branch 'master' into sunxd/move_ad

2 parents 67860e6 + cf647b1
File tree

8 files changed (+70, -85 lines)

Project.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.30.2"
+version = "0.30.4"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/Inference.jl
Lines changed: 17 additions & 30 deletions

@@ -127,6 +127,23 @@ DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.l
 # Algorithm for sampling from the prior
 struct Prior <: InferenceAlgorithm end

+function AbstractMCMC.step(
+    rng::Random.AbstractRNG,
+    model::DynamicPPL.Model,
+    sampler::DynamicPPL.Sampler{<:Prior},
+    state=nothing;
+    kwargs...,
+)
+    vi = last(DynamicPPL.evaluate!!(
+        model,
+        VarInfo(),
+        SamplingContext(
+            rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()
+        )
+    ))
+    return vi, nothing
+end
+
 """
     mh_accept(logp_current::Real, logp_proposal::Real, log_proposal_ratio::Real)

@@ -242,36 +259,6 @@ function AbstractMCMC.sample(
         chain_type=chain_type, progress=progress, kwargs...)
 end

-function AbstractMCMC.sample(
-    rng::AbstractRNG,
-    model::AbstractModel,
-    alg::Prior,
-    ensemble::AbstractMCMC.AbstractMCMCEnsemble,
-    N::Integer,
-    n_chains::Integer;
-    chain_type=DynamicPPL.default_chain_type(alg),
-    progress=PROGRESS[],
-    kwargs...
-)
-    return AbstractMCMC.sample(rng, model, SampleFromPrior(), ensemble, N, n_chains;
-                               chain_type, progress, kwargs...)
-end
-
-function AbstractMCMC.sample(
-    rng::AbstractRNG,
-    model::AbstractModel,
-    alg::Prior,
-    N::Integer;
-    chain_type=DynamicPPL.default_chain_type(alg),
-    resume_from=nothing,
-    initial_state=DynamicPPL.loadstate(resume_from),
-    progress=PROGRESS[],
-    kwargs...
-)
-    return AbstractMCMC.mcmcsample(rng, model, SampleFromPrior(), N;
-                                   chain_type, initial_state, progress, kwargs...)
-end
-
 ##########################
 # Chain making utilities #
 ##########################
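
Note: with `Prior` now implementing the generic `AbstractMCMC.step` interface, the two deleted `sample` overloads become redundant; the standard AbstractMCMC sampling loop handles ensembles, progress, and chain construction. A minimal usage sketch (the toy model is illustrative, not taken from the diff):

using Turing

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# Each call to the new `AbstractMCMC.step` method for `Sampler{<:Prior}`
# evaluates the model under `PriorContext` and returns an independent draw.
chain = sample(demo(), Prior(), 100)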

src/mcmc/emcee.jl
Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ function AbstractMCMC.step(
         vis[1],
         map(vis) do vi
             vi = DynamicPPL.link!!(vi, spl, model)
-            AMH.Transition(vi[spl], getlogp(vi))
+            AMH.Transition(vi[spl], getlogp(vi), false)
         end
     )

src/mcmc/mh.jl
Lines changed: 2 additions & 2 deletions

@@ -386,7 +386,7 @@ function propose!!(

     # Create a sampler and the previous transition.
     mh_sampler = AMH.MetropolisHastings(dt)
-    prev_trans = AMH.Transition(vt, getlogp(vi))
+    prev_trans = AMH.Transition(vt, getlogp(vi), false)

     # Make a new transition.
     densitymodel = AMH.DensityModel(

@@ -421,7 +421,7 @@ function propose!!(

     # Create a sampler and the previous transition.
     mh_sampler = AMH.MetropolisHastings(spl.alg.proposals)
-    prev_trans = AMH.Transition(vals, getlogp(vi))
+    prev_trans = AMH.Transition(vals, getlogp(vi), false)

     # Make a new transition.
     densitymodel = AMH.DensityModel(
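
Note: both this file and src/mcmc/emcee.jl now pass a third positional argument to `AMH.Transition`. A hedged sketch of what that argument presumably is, assuming AdvancedMH's `Transition` gained an acceptance flag as its third field (the diff itself does not show the struct definition):

using AdvancedMH
const AMH = AdvancedMH

# Assumption: the three fields are the parameter values, the log density,
# and an acceptance flag; `false` here just seeds the sampler with the
# current state rather than marking it as a freshly accepted proposal.
prev = AMH.Transition([0.0, 1.0], -1.23, false)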

src/optimisation/Optimisation.jl
Lines changed: 7 additions & 7 deletions

@@ -283,17 +283,17 @@ function optim_function(
     model::Model,
     estimator::Union{MLE, MAP};
     constrained::Bool=true,
-    autoad::Union{Nothing, AbstractADType}=NoAD(),
+    adtype::Union{Nothing, AbstractADType}=NoAD(),
 )
-    if autoad === nothing
-        Base.depwarn("the use of `autoad=nothing` is deprecated, please use `autoad=SciMLBase.NoAD()`", :optim_function)
+    if adtype === nothing
+        Base.depwarn("the use of `adtype=nothing` is deprecated, please use `adtype=SciMLBase.NoAD()`", :optim_function)
     end

     obj, init, t = optim_objective(model, estimator; constrained=constrained)

     l(x, _) = obj(x)
-    f = if autoad isa AbstractADType && autoad !== NoAD()
-        OptimizationFunction(l, autoad)
+    f = if adtype isa AbstractADType && adtype !== NoAD()
+        OptimizationFunction(l, adtype)
     else
         OptimizationFunction(
             l;

@@ -310,10 +310,10 @@ function optim_problem(
     estimator::Union{MAP, MLE};
     constrained::Bool=true,
     init_theta=nothing,
-    autoad::Union{Nothing, AbstractADType}=NoAD(),
+    adtype::Union{Nothing, AbstractADType}=NoAD(),
     kwargs...,
 )
-    f, init, transform = optim_function(model, estimator; constrained=constrained, autoad=autoad)
+    f, init, transform = optim_function(model, estimator; constrained=constrained, adtype=adtype)

     u0 = init_theta === nothing ? init() : init(init_theta)
     prob = OptimizationProblem(f, u0; kwargs...)
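
Note: callers of these helpers now select the AD backend with `adtype` instead of `autoad`. A hedged usage sketch; the `Turing.Optimisation` module path is inferred from the file location, `AutoForwardDiff` from ADTypes.jl is only an example backend, and whether Optimization.jl must be loaded separately is not shown in the diff:

using Turing, Optimization, ADTypes

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))
end

model = gdemo([1.5, 2.0])

# The keyword is now `adtype` (previously `autoad`); passing an
# `AbstractADType` builds an `OptimizationFunction` with that backend.
f, init, transform = Turing.Optimisation.optim_function(
    model, Turing.Optimisation.MLE();
    adtype=ADTypes.AutoForwardDiff(),
)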

src/variational/advi.jl
Lines changed: 22 additions & 44 deletions

@@ -1,45 +1,17 @@
-# TODO(torfjelde): Find a better solution.
-struct Vec{N,B} <: Bijectors.Bijector
-    b::B
-    size::NTuple{N, Int}
-end
-
-Bijectors.inverse(f::Vec) = Vec(Bijectors.inverse(f.b), f.size)
-
-Bijectors.output_length(f::Vec, sz) = Bijectors.output_length(f.b, sz)
-Bijectors.output_length(f::Vec, n::Int) = Bijectors.output_length(f.b, n)
-
-function Bijectors.with_logabsdet_jacobian(f::Vec, x)
-    return Bijectors.transform(f, x), Bijectors.logabsdetjac(f, x)
-end
-
-function Bijectors.transform(f::Vec, x::AbstractVector)
-    # Reshape into shape compatible with wrapped bijector and then `vec` again.
-    return vec(f.b(reshape(x, f.size)))
-end
-
-function Bijectors.transform(f::Vec{N,<:Bijectors.Inverse}, x::AbstractVector) where N
-    # Reshape into shape compatible with original (forward) bijector and then `vec` again.
-    return vec(f.b(reshape(x, Bijectors.output_length(f.b.orig, prod(f.size)))))
-end
-
-function Bijectors.transform(f::Vec, x::AbstractMatrix)
-    # At the moment we do batching for higher-than-1-dim spaces by simply using
-    # lists of inputs rather than `AbstractArray` with `N + 1` dimension.
-    cols = Iterators.Stateful(eachcol(x))
-    # Make `init` a matrix to ensure type-stability
-    init = reshape(f(first(cols)), :, 1)
-    return mapreduce(f, hcat, cols; init = init)
-end
-
-function Bijectors.logabsdetjac(f::Vec, x::AbstractVector)
-    return Bijectors.logabsdetjac(f.b, reshape(x, f.size))
-end
+# TODO: Move to Bijectors.jl if we find further use for this.
+"""
+    wrap_in_vec_reshape(f, in_size)

-function Bijectors.logabsdetjac(f::Vec, x::AbstractMatrix)
-    return map(eachcol(x)) do x_
-        Bijectors.logabsdetjac(f, x_)
-    end
+Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces
+a vector of length `prod(Bijectors.output(f, in_size))`.
+"""
+function wrap_in_vec_reshape(f, in_size)
+    vec_in_length = prod(in_size)
+    reshape_inner = Bijectors.Reshape((vec_in_length,), in_size)
+    out_size = Bijectors.output_size(f, in_size)
+    vec_out_length = prod(out_size)
+    reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,))
+    return reshape_outer ∘ f ∘ reshape_inner
 end

@@ -83,7 +55,7 @@ function Bijectors.bijector(
         if d isa Distributions.UnivariateDistribution
             b
         else
-            Vec(b, size(d))
+            wrap_in_vec_reshape(b, size(d))
         end
     end

@@ -106,7 +78,10 @@ meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model)
 function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
     # Setup.
     varinfo = DynamicPPL.VarInfo(model)
-    num_params = length(varinfo[DynamicPPL.SampleFromPrior()])
+    # Use linked `varinfo` to determine the correct number of parameters.
+    # TODO: Replace with `length` once this is implemented for `VarInfo`.
+    varinfo_linked = DynamicPPL.link(varinfo, model)
+    num_params = length(varinfo_linked[:])

     # initial params
     μ = randn(rng, num_params)

@@ -134,7 +109,10 @@ function AdvancedVI.update(
     td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
     θ::AbstractArray,
 )
-    μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
+    # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality,
+    # so we need to use the length of the underlying distribution `td.dist` here.
+    # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends.
+    μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end]
     return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
 end
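
Note: these changes target dimension-changing bijectors. `wrap_in_vec_reshape` turns a shape-changing bijector into a vector-to-vector one, and `meanfield` now counts parameters in linked (unconstrained) space, whose dimension can differ from the constrained one. A hedged sketch of that distinction, built only from calls that appear in the diff (the model is illustrative):

using Turing, DynamicPPL

@model function simplex_demo()
    p ~ Dirichlet(ones(3))   # constrained to the probability simplex
end

model = simplex_demo()
varinfo = DynamicPPL.VarInfo(model)

# Linking maps the parameters to unconstrained space; with a
# dimension-changing bijector the linked length is the one the
# mean-field normal approximation must have.
varinfo_linked = DynamicPPL.link(varinfo, model)
num_params = length(varinfo_linked[:])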

test/mcmc/Inference.jl
Lines changed: 15 additions & 0 deletions

@@ -140,6 +140,21 @@
         @test all(haskey(x, :lp) for x in chains)
         @test mean(x[:s][1] for x in chains) ≈ 3 atol=0.1
         @test mean(x[:m][1] for x in chains) ≈ 0 atol=0.1
+
+        @testset "#2169" begin
+            # Not exactly the same as the issue, but similar.
+            @model function issue2169_model()
+                if DynamicPPL.leafcontext(__context__) isa DynamicPPL.PriorContext
+                    x ~ Normal(0, 1)
+                else
+                    x ~ Normal(1000, 1)
+                end
+            end
+
+            model = issue2169_model()
+            chain = sample(model, Prior(), 10)
+            @test all(mean(chain[:x]) .< 5)
+        end
     end

     @testset "chain ordering" begin

test/variational/advi.jl
Lines changed: 5 additions & 0 deletions

@@ -64,5 +64,10 @@
         x0_inv = inverse(b)(z0)
         @test size(x0_inv) == size(x0)
         @test all(x0 .≈ x0_inv)
+
+        # And regression for https://github.com/TuringLang/Turing.jl/issues/2160.
+        q = vi(m, ADVI(10, 1000))
+        x = rand(q, 1000)
+        @test mean(eachcol(x)) ≈ [0.5, 0.5] atol=0.1
     end
 end
