Skip to content

Commit e4c5bbd

Browse files
penelopeysmmhauru
andauthored
Fix DynamicPPL / MCMCChains methods (#1076)
* Reimplement pointwise_logdensities (almost completely) * Move logjoint, logprior, ... as well * Fix imports, etc * Remove tests that are failing (yes I learnt this from Claude) * Changelog * logpdf * fix docstrings * allow dict output * changelog * fix some comments * fix tests * Fix more imports * Remove stray n Co-authored-by: Markus Hauru <markus@mhauru.org> * Expand `logprior`, `loglikelihood`, and `logjoint` docstrings --------- Co-authored-by: Markus Hauru <markus@mhauru.org>
1 parent 9bd8f16 commit e4c5bbd

File tree

8 files changed

+332
-562
lines changed

8 files changed

+332
-562
lines changed

HISTORY.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
6161
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
6262
`loadstate` is exported from DynamicPPL.
6363

64-
### Change of default keytype of `pointwise_logdensities`
64+
### Change of output type for `pointwise_logdensities`
6565

66-
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.
66+
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` when called on `MCMCChains.Chains` objects, now return new `MCMCChains.Chains` objects by default, instead of dictionaries of matrices.
67+
68+
If you want the old behaviour, you can pass `OrderedDict` as the third argument, i.e., `pointwise_logdensities(model, chain, OrderedDict)`.
6769

6870
**Other changes**
6971

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,290 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
292292
end
293293
end
294294

295+
"""
296+
DynamicPPL.pointwise_logdensities(
297+
model::DynamicPPL.Model,
298+
chain::MCMCChains.Chains,
299+
::Type{Tout}=MCMCChains.Chains
300+
::Val{whichlogprob}=Val(:both),
301+
)
302+
303+
Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where
304+
the log-density of each variable at each sample is stored (rather than its value).
305+
306+
`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
307+
`:likelihood`.
308+
309+
You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName,
310+
Matrix{Float64}}` instead.
311+
312+
See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref),
313+
[`DynamicPPL.pointwise_prior_logdensities`](@ref).
314+
315+
# Examples
316+
317+
```jldoctest pointwise-logdensities-chains; setup=:(using Distributions)
318+
julia> using MCMCChains
319+
320+
julia> @model function demo(xs, y)
321+
s ~ InverseGamma(2, 3)
322+
m ~ Normal(0, √s)
323+
for i in eachindex(xs)
324+
xs[i] ~ Normal(m, √s)
325+
end
326+
y ~ Normal(m, √s)
327+
end
328+
demo (generic function with 2 methods)
329+
330+
julia> # Example observations.
331+
model = demo([1.0, 2.0, 3.0], [4.0]);
332+
333+
julia> # A chain with 3 iterations.
334+
chain = Chains(
335+
reshape(1.:6., 3, 2),
336+
[:s, :m];
337+
info=(varname_to_symbol=Dict(
338+
@varname(s) => :s,
339+
@varname(m) => :m,
340+
),),
341+
);
342+
343+
julia> plds = pointwise_logdensities(model, chain)
344+
Chains MCMC chain (3×6×1 Array{Float64, 3}):
345+
346+
Iterations = 1:1:3
347+
Number of chains = 1
348+
Samples per chain = 3
349+
parameters = s, m, xs[1], xs[2], xs[3], y
350+
[...]
351+
352+
julia> plds[:s]
353+
2-dimensional AxisArray{Float64,2,...} with axes:
354+
:iter, 1:1:3
355+
:chain, 1:1
356+
And data, a 3×1 Matrix{Float64}:
357+
-0.8027754226637804
358+
-1.3822169643436162
359+
-2.0986122886681096
360+
361+
julia> # The above is the same as:
362+
logpdf.(InverseGamma(2, 3), chain[:s])
363+
3×1 Matrix{Float64}:
364+
-0.8027754226637804
365+
-1.3822169643436162
366+
-2.0986122886681096
367+
```
368+
369+
julia> # Alternatively:
370+
plds_dict = pointwise_logdensities(model, chain, OrderedDict)
371+
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
372+
s => [-0.802775; -1.38222; -2.09861;;]
373+
m => [-8.91894; -7.51551; -7.46824;;]
374+
xs[1] => [-5.41894; -5.26551; -5.63491;;]
375+
xs[2] => [-2.91894; -3.51551; -4.13491;;]
376+
xs[3] => [-1.41894; -2.26551; -2.96824;;]
377+
y => [-0.918939; -1.51551; -2.13491;;]
378+
"""
379+
function DynamicPPL.pointwise_logdensities(
380+
model::DynamicPPL.Model,
381+
chain::MCMCChains.Chains,
382+
::Type{Tout}=MCMCChains.Chains,
383+
::Val{whichlogprob}=Val(:both),
384+
) where {whichlogprob,Tout}
385+
vi = DynamicPPL.VarInfo(model)
386+
acc = DynamicPPL.PointwiseLogProbAccumulator{whichlogprob}()
387+
accname = DynamicPPL.accumulator_name(acc)
388+
vi = DynamicPPL.setaccs!!(vi, (acc,))
389+
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
390+
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
391+
pointwise_logps = map(iters) do (sample_idx, chain_idx)
392+
# Extract values from the chain
393+
values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
394+
# Re-evaluate the model
395+
_, vi = DynamicPPL.init!!(
396+
model,
397+
vi,
398+
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
399+
)
400+
DynamicPPL.getacc(vi, Val(accname)).logps
401+
end
402+
403+
# pointwise_logps is a matrix of OrderedDicts
404+
all_keys = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
405+
for d in pointwise_logps
406+
union!(all_keys, DynamicPPL.OrderedCollections.OrderedSet(keys(d)))
407+
end
408+
# this is a 3D array: (iterations, variables, chains)
409+
new_data = [
410+
get(pointwise_logps[iter, chain], k, missing) for
411+
iter in 1:size(pointwise_logps, 1), k in all_keys,
412+
chain in 1:size(pointwise_logps, 2)
413+
]
414+
415+
if Tout == MCMCChains.Chains
416+
return MCMCChains.Chains(new_data, Symbol.(collect(all_keys)))
417+
elseif Tout <: AbstractDict
418+
return Tout{DynamicPPL.VarName,Matrix{Float64}}(
419+
k => new_data[:, i, :] for (i, k) in enumerate(all_keys)
420+
)
421+
end
422+
end
423+
424+
"""
425+
DynamicPPL.pointwise_loglikelihoods(
426+
model::DynamicPPL.Model,
427+
chain::MCMCChains.Chains,
428+
::Type{Tout}=MCMCChains.Chains
429+
)
430+
431+
Compute the pointwise log-likelihoods of the model given the chain. This is the same as
432+
`pointwise_logdensities(model, chain)`, but only including the likelihood terms.
433+
434+
See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref).
435+
"""
436+
function DynamicPPL.pointwise_loglikelihoods(
437+
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
438+
) where {Tout}
439+
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:likelihood))
440+
end
441+
442+
"""
443+
DynamicPPL.pointwise_prior_logdensities(
444+
model::DynamicPPL.Model,
445+
chain::MCMCChains.Chains
446+
)
447+
448+
Compute the pointwise log-prior-densities of the model given the chain. This is the same as
449+
`pointwise_logdensities(model, chain)`, but only including the prior terms.
450+
451+
See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref).
452+
"""
453+
function DynamicPPL.pointwise_prior_logdensities(
454+
model::DynamicPPL.Model, chain::MCMCChains.Chains, ::Type{Tout}=MCMCChains.Chains
455+
) where {Tout}
456+
return DynamicPPL.pointwise_logdensities(model, chain, Tout, Val(:prior))
457+
end
458+
459+
"""
460+
logjoint(model::Model, chain::MCMCChains.Chains)
461+
462+
Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`.
463+
464+
# Examples
465+
466+
```jldoctest
467+
julia> using MCMCChains, Distributions
468+
469+
julia> @model function demo_model(x)
470+
s ~ InverseGamma(2, 3)
471+
m ~ Normal(0, sqrt(s))
472+
for i in eachindex(x)
473+
x[i] ~ Normal(m, sqrt(s))
474+
end
475+
end;
476+
477+
julia> # Construct a chain of samples using MCMCChains.
478+
# This sets s = 0.5 and m = 1.0 for all three samples.
479+
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
480+
481+
julia> logjoint(demo_model([1., 2.]), chain)
482+
3×1 Matrix{Float64}:
483+
-5.440428709758045
484+
-5.440428709758045
485+
-5.440428709758045
486+
```
487+
"""
488+
function DynamicPPL.logjoint(model::DynamicPPL.Model, chain::MCMCChains.Chains)
489+
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
490+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
491+
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
492+
vn_parent => DynamicPPL.values_from_chain(
493+
var_info, vn_parent, chain, chain_idx, iteration_idx
494+
) for vn_parent in keys(var_info)
495+
)
496+
DynamicPPL.logjoint(model, argvals_dict)
497+
end
498+
end
499+
500+
"""
501+
loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
502+
503+
Return an array of log likelihoods evaluated at each sample in an MCMC `chain`.
504+
# Examples
505+
506+
```jldoctest
507+
julia> using MCMCChains, Distributions
508+
509+
julia> @model function demo_model(x)
510+
s ~ InverseGamma(2, 3)
511+
m ~ Normal(0, sqrt(s))
512+
for i in eachindex(x)
513+
x[i] ~ Normal(m, sqrt(s))
514+
end
515+
end;
516+
517+
julia> # Construct a chain of samples using MCMCChains.
518+
# This sets s = 0.5 and m = 1.0 for all three samples.
519+
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
520+
521+
julia> loglikelihood(demo_model([1., 2.]), chain)
522+
3×1 Matrix{Float64}:
523+
-2.1447298858494
524+
-2.1447298858494
525+
-2.1447298858494
526+
```
527+
"""
528+
function DynamicPPL.loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains)
529+
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
530+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
531+
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
532+
vn_parent => DynamicPPL.values_from_chain(
533+
var_info, vn_parent, chain, chain_idx, iteration_idx
534+
) for vn_parent in keys(var_info)
535+
)
536+
DynamicPPL.loglikelihood(model, argvals_dict)
537+
end
538+
end
539+
540+
"""
541+
logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
542+
543+
Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`.
544+
545+
# Examples
546+
547+
```jldoctest
548+
julia> using MCMCChains, Distributions
549+
550+
julia> @model function demo_model(x)
551+
s ~ InverseGamma(2, 3)
552+
m ~ Normal(0, sqrt(s))
553+
for i in eachindex(x)
554+
x[i] ~ Normal(m, sqrt(s))
555+
end
556+
end;
557+
558+
julia> # Construct a chain of samples using MCMCChains.
559+
# This sets s = 0.5 and m = 1.0 for all three samples.
560+
chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]);
561+
562+
julia> logprior(demo_model([1., 2.]), chain)
563+
3×1 Matrix{Float64}:
564+
-3.2956988239086447
565+
-3.2956988239086447
566+
-3.2956988239086447
567+
```
568+
"""
569+
function DynamicPPL.logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains)
570+
var_info = DynamicPPL.VarInfo(model) # extract variables info from the model
571+
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
572+
argvals_dict = DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,Any}(
573+
vn_parent => DynamicPPL.values_from_chain(
574+
var_info, vn_parent, chain, chain_idx, iteration_idx
575+
) for vn_parent in keys(var_info)
576+
)
577+
DynamicPPL.logprior(model, argvals_dict)
578+
end
579+
end
580+
295581
end

0 commit comments

Comments
 (0)