@@ -292,4 +292,290 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
292292    end 
293293end 
294294
295+ """ 
296+     DynamicPPL.pointwise_logdensities( 
297+         model::DynamicPPL.Model, 
298+         chain::MCMCChains.Chains, 
299+         ::Type{Tout}=MCMCChains.Chains 
300+         ::Val{whichlogprob}=Val(:both), 
301+     ) 
302+ 
303+ Runs `model` on each sample in `chain`, returning a new `MCMCChains.Chains` object where 
304+ the log-density of each variable at each sample is stored (rather than its value). 
305+ 
306+ `whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or 
307+ `:likelihood`. 
308+ 
309+ You can pass `Tout=OrderedDict` to get the result as an `OrderedDict{VarName, 
310+ Matrix{Float64}}` instead. 
311+ 
312+ See also: [`DynamicPPL.pointwise_loglikelihoods`](@ref), 
313+ [`DynamicPPL.pointwise_prior_logdensities`](@ref). 
314+ 
315+ # Examples 
316+ 
317+ ```jldoctest pointwise-logdensities-chains; setup=:(using Distributions) 
318+ julia> using MCMCChains 
319+ 
320+ julia> @model function demo(xs, y) 
321+            s ~ InverseGamma(2, 3) 
322+            m ~ Normal(0, √s) 
323+            for i in eachindex(xs) 
324+                xs[i] ~ Normal(m, √s) 
325+            end 
326+            y ~ Normal(m, √s) 
327+        end 
328+ demo (generic function with 2 methods) 
329+ 
330+ julia> # Example observations. 
331+        model = demo([1.0, 2.0, 3.0], [4.0]); 
332+ 
333+ julia> # A chain with 3 iterations. 
334+        chain = Chains( 
335+            reshape(1.:6., 3, 2), 
336+            [:s, :m]; 
337+            info=(varname_to_symbol=Dict( 
338+                @varname(s) => :s, 
339+                @varname(m) => :m, 
340+            ),), 
341+        ); 
342+ 
343+ julia> plds = pointwise_logdensities(model, chain) 
344+ Chains MCMC chain (3×6×1 Array{Float64, 3}): 
345+ 
346+ Iterations        = 1:1:3 
347+ Number of chains  = 1 
348+ Samples per chain = 3 
349+ parameters        = s, m, xs[1], xs[2], xs[3], y 
350+ [...] 
351+ 
352+ julia> plds[:s] 
353+ 2-dimensional AxisArray{Float64,2,...} with axes: 
354+     :iter, 1:1:3 
355+     :chain, 1:1 
356+ And data, a 3×1 Matrix{Float64}: 
357+  -0.8027754226637804 
358+  -1.3822169643436162 
359+  -2.0986122886681096 
360+ 
361+ julia> # The above is the same as: 
362+        logpdf.(InverseGamma(2, 3), chain[:s]) 
363+ 3×1 Matrix{Float64}: 
364+  -0.8027754226637804 
365+  -1.3822169643436162 
366+  -2.0986122886681096 
367+ ``` 
368+ 
369+ julia> # Alternatively: 
370+        plds_dict = pointwise_logdensities(model, chain, OrderedDict) 
371+ OrderedDict{VarName, Matrix{Float64}} with 6 entries: 
372+   s     => [-0.802775; -1.38222; -2.09861;;] 
373+   m     => [-8.91894; -7.51551; -7.46824;;] 
374+   xs[1] => [-5.41894; -5.26551; -5.63491;;] 
375+   xs[2] => [-2.91894; -3.51551; -4.13491;;] 
376+   xs[3] => [-1.41894; -2.26551; -2.96824;;] 
377+   y     => [-0.918939; -1.51551; -2.13491;;] 
378+ """ 
379+ function  DynamicPPL. pointwise_logdensities (
380+     model:: DynamicPPL.Model ,
381+     chain:: MCMCChains.Chains ,
382+     :: Type{Tout} = MCMCChains. Chains,
383+     :: Val{whichlogprob} = Val (:both ),
384+ ) where  {whichlogprob,Tout}
385+     vi =  DynamicPPL. VarInfo (model)
386+     acc =  DynamicPPL. PointwiseLogProbAccumulator {whichlogprob} ()
387+     accname =  DynamicPPL. accumulator_name (acc)
388+     vi =  DynamicPPL. setaccs!! (vi, (acc,))
389+     parameter_only_chain =  MCMCChains. get_sections (chain, :parameters )
390+     iters =  Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
391+     pointwise_logps =  map (iters) do  (sample_idx, chain_idx)
392+         #  Extract values from the chain
393+         values_dict =  chain_sample_to_varname_dict (parameter_only_chain, sample_idx, chain_idx)
394+         #  Re-evaluate the model
395+         _, vi =  DynamicPPL. init!! (
396+             model,
397+             vi,
398+             DynamicPPL. InitFromParams (values_dict, DynamicPPL. InitFromPrior ()),
399+         )
400+         DynamicPPL. getacc (vi, Val (accname)). logps
401+     end 
402+ 
403+     #  pointwise_logps is a matrix of OrderedDicts
404+     all_keys =  DynamicPPL. OrderedCollections. OrderedSet {DynamicPPL.VarName} ()
405+     for  d in  pointwise_logps
406+         union! (all_keys, DynamicPPL. OrderedCollections. OrderedSet (keys (d)))
407+     end 
408+     #  this is a 3D array: (iterations, variables, chains)
409+     new_data =  [
410+         get (pointwise_logps[iter, chain], k, missing ) for 
411+         iter in  1 : size (pointwise_logps, 1 ), k in  all_keys,
412+         chain in  1 : size (pointwise_logps, 2 )
413+     ]
414+ 
415+     if  Tout ==  MCMCChains. Chains
416+         return  MCMCChains. Chains (new_data, Symbol .(collect (all_keys)))
417+     elseif  Tout <:  AbstractDict 
418+         return  Tout {DynamicPPL.VarName,Matrix{Float64}} (
419+             k =>  new_data[:, i, :] for  (i, k) in  enumerate (all_keys)
420+         )
421+     end 
422+ end 
423+ 
424+ """ 
425+     DynamicPPL.pointwise_loglikelihoods( 
426+         model::DynamicPPL.Model, 
427+         chain::MCMCChains.Chains, 
428+         ::Type{Tout}=MCMCChains.Chains 
429+     ) 
430+ 
431+ Compute the pointwise log-likelihoods of the model given the chain. This is the same as 
432+ `pointwise_logdensities(model, chain)`, but only including the likelihood terms. 
433+ 
434+ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_prior_logdensities`](@ref). 
435+ """ 
436+ function  DynamicPPL. pointwise_loglikelihoods (
437+     model:: DynamicPPL.Model , chain:: MCMCChains.Chains , :: Type{Tout} = MCMCChains. Chains
438+ ) where  {Tout}
439+     return  DynamicPPL. pointwise_logdensities (model, chain, Tout, Val (:likelihood ))
440+ end 
441+ 
442+ """ 
443+     DynamicPPL.pointwise_prior_logdensities( 
444+         model::DynamicPPL.Model, 
445+         chain::MCMCChains.Chains 
446+     ) 
447+ 
448+ Compute the pointwise log-prior-densities of the model given the chain. This is the same as 
449+ `pointwise_logdensities(model, chain)`, but only including the prior terms. 
450+ 
451+ See also: [`DynamicPPL.pointwise_logdensities`](@ref), [`DynamicPPL.pointwise_loglikelihoods`](@ref). 
452+ """ 
453+ function  DynamicPPL. pointwise_prior_logdensities (
454+     model:: DynamicPPL.Model , chain:: MCMCChains.Chains , :: Type{Tout} = MCMCChains. Chains
455+ ) where  {Tout}
456+     return  DynamicPPL. pointwise_logdensities (model, chain, Tout, Val (:prior ))
457+ end 
458+ 
459+ """ 
460+     logjoint(model::Model, chain::MCMCChains.Chains) 
461+ 
462+ Return an array of log joint probabilities evaluated at each sample in an MCMC `chain`. 
463+ 
464+ # Examples 
465+ 
466+ ```jldoctest 
467+ julia> using MCMCChains, Distributions 
468+ 
469+ julia> @model function demo_model(x) 
470+            s ~ InverseGamma(2, 3) 
471+            m ~ Normal(0, sqrt(s)) 
472+            for i in eachindex(x) 
473+                x[i] ~ Normal(m, sqrt(s)) 
474+            end 
475+        end; 
476+ 
477+ julia> # Construct a chain of samples using MCMCChains. 
478+        # This sets s = 0.5 and m = 1.0 for all three samples. 
479+        chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); 
480+ 
481+ julia> logjoint(demo_model([1., 2.]), chain) 
482+ 3×1 Matrix{Float64}: 
483+  -5.440428709758045 
484+  -5.440428709758045 
485+  -5.440428709758045 
486+ ``` 
487+ """ 
488+ function  DynamicPPL. logjoint (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
489+     var_info =  DynamicPPL. VarInfo (model) #  extract variables info from the model
490+     map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do  (iteration_idx, chain_idx)
491+         argvals_dict =  DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
492+             vn_parent =>  DynamicPPL. values_from_chain (
493+                 var_info, vn_parent, chain, chain_idx, iteration_idx
494+             ) for  vn_parent in  keys (var_info)
495+         )
496+         DynamicPPL. logjoint (model, argvals_dict)
497+     end 
498+ end 
499+ 
500+ """ 
501+     loglikelihood(model::DynamicPPL.Model, chain::MCMCChains.Chains) 
502+ 
503+ Return an array of log likelihoods evaluated at each sample in an MCMC `chain`. 
504+ # Examples 
505+ 
506+ ```jldoctest 
507+ julia> using MCMCChains, Distributions 
508+ 
509+ julia> @model function demo_model(x) 
510+            s ~ InverseGamma(2, 3) 
511+            m ~ Normal(0, sqrt(s)) 
512+            for i in eachindex(x) 
513+                x[i] ~ Normal(m, sqrt(s)) 
514+            end 
515+        end; 
516+ 
517+ julia> # Construct a chain of samples using MCMCChains. 
518+        # This sets s = 0.5 and m = 1.0 for all three samples. 
519+        chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); 
520+ 
521+ julia> loglikelihood(demo_model([1., 2.]), chain) 
522+ 3×1 Matrix{Float64}: 
523+  -2.1447298858494 
524+  -2.1447298858494 
525+  -2.1447298858494 
526+ ``` 
527+ """ 
528+ function  DynamicPPL. loglikelihood (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
529+     var_info =  DynamicPPL. VarInfo (model) #  extract variables info from the model
530+     map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do  (iteration_idx, chain_idx)
531+         argvals_dict =  DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
532+             vn_parent =>  DynamicPPL. values_from_chain (
533+                 var_info, vn_parent, chain, chain_idx, iteration_idx
534+             ) for  vn_parent in  keys (var_info)
535+         )
536+         DynamicPPL. loglikelihood (model, argvals_dict)
537+     end 
538+ end 
539+ 
540+ """ 
541+     logprior(model::DynamicPPL.Model, chain::MCMCChains.Chains) 
542+ 
543+ Return an array of log prior probabilities evaluated at each sample in an MCMC `chain`. 
544+ 
545+ # Examples 
546+ 
547+ ```jldoctest 
548+ julia> using MCMCChains, Distributions 
549+ 
550+ julia> @model function demo_model(x) 
551+            s ~ InverseGamma(2, 3) 
552+            m ~ Normal(0, sqrt(s)) 
553+            for i in eachindex(x) 
554+                x[i] ~ Normal(m, sqrt(s)) 
555+            end 
556+        end; 
557+ 
558+ julia> # Construct a chain of samples using MCMCChains. 
559+        # This sets s = 0.5 and m = 1.0 for all three samples. 
560+        chain = Chains(repeat([0.5 1.0;;;], 3, 1, 1), [:s, :m]); 
561+ 
562+ julia> logprior(demo_model([1., 2.]), chain) 
563+ 3×1 Matrix{Float64}: 
564+  -3.2956988239086447 
565+  -3.2956988239086447 
566+  -3.2956988239086447 
567+ ``` 
568+ """ 
569+ function  DynamicPPL. logprior (model:: DynamicPPL.Model , chain:: MCMCChains.Chains )
570+     var_info =  DynamicPPL. VarInfo (model) #  extract variables info from the model
571+     map (Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))) do  (iteration_idx, chain_idx)
572+         argvals_dict =  DynamicPPL. OrderedCollections. OrderedDict {DynamicPPL.VarName,Any} (
573+             vn_parent =>  DynamicPPL. values_from_chain (
574+                 var_info, vn_parent, chain, chain_idx, iteration_idx
575+             ) for  vn_parent in  keys (var_info)
576+         )
577+         DynamicPPL. logprior (model, argvals_dict)
578+     end 
579+ end 
580+ 
295581end 
0 commit comments