Skip to content

Commit 57b338b

Browse files
committed
Revert "Add to_chains and from_chains function (#1087)"
This reverts commit 11b7e01.
1 parent 11b7e01 commit 57b338b

File tree

10 files changed

+86
-459
lines changed

10 files changed

+86
-459
lines changed

HISTORY.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# DynamicPPL Changelog
22

3-
## 0.38.3
4-
5-
Added a new exported struct, `DynamicPPL.ParamsWithStats`, and a corresponding function `DynamicPPL.to_chains`, which automatically converts a collection of `ParamsWithStats` to a given Chains type.
6-
73
## 0.38.2
84

95
Added a compatibility entry for JET@0.11.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.3"
3+
version = "0.38.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -505,21 +505,3 @@ There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_va
505505
DynamicPPL.Experimental.determine_suitable_varinfo
506506
DynamicPPL.Experimental.is_suitable_varinfo
507507
```
508-
509-
### Converting VarInfos to chains
510-
511-
It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
512-
This can be accomplished with the following:
513-
514-
```@docs
515-
DynamicPPL.ParamsWithStats
516-
DynamicPPL.to_chains
517-
```
518-
519-
Furthermore, one can convert chains back into a collection of parameter dictionaries and/or stats with:
520-
521-
```@docs
522-
DynamicPPL.from_chains
523-
```
524-
525-
This is useful if you want to use the result of a chain in further model evaluations.

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 56 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -36,113 +36,6 @@ function chain_sample_to_varname_dict(
3636
return d
3737
end
3838

39-
"""
40-
DynamicPPL.to_chains(
41-
::Type{MCMCChains.Chains},
42-
params_and_stats::AbstractArray{<:ParamsWithStats}
43-
)
44-
45-
Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.
46-
"""
47-
function DynamicPPL.to_chains(
48-
::Type{MCMCChains.Chains},
49-
params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats},
50-
)
51-
# Handle parameters
52-
all_vn_leaves = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
53-
split_dicts = map(params_and_stats) do ps
54-
# Separate into individual VarNames.
55-
vn_leaves_and_vals = if isempty(ps.params)
56-
Tuple{DynamicPPL.VarName,Any}[]
57-
else
58-
iters = map(
59-
AbstractPPL.varname_and_value_leaves,
60-
keys(ps.params),
61-
values(ps.params),
62-
)
63-
mapreduce(collect, vcat, iters)
64-
end
65-
vn_leaves = map(first, vn_leaves_and_vals)
66-
vals = map(last, vn_leaves_and_vals)
67-
for vn_leaf in vn_leaves
68-
push!(all_vn_leaves, vn_leaf)
69-
end
70-
DynamicPPL.OrderedCollections.OrderedDict(zip(vn_leaves, vals))
71-
end
72-
vn_leaves = collect(all_vn_leaves)
73-
param_vals = [
74-
get(split_dicts[i, j], key, missing) for i in eachindex(axes(split_dicts, 1)),
75-
key in vn_leaves, j in eachindex(axes(split_dicts, 2))
76-
]
77-
param_symbols = map(Symbol, vn_leaves)
78-
# Handle statistics
79-
stat_keys = DynamicPPL.OrderedCollections.OrderedSet{Symbol}()
80-
for ps in params_and_stats
81-
for k in keys(ps.stats)
82-
push!(stat_keys, k)
83-
end
84-
end
85-
stat_keys = collect(stat_keys)
86-
stat_vals = [
87-
get(params_and_stats[i, j].stats, key, missing) for
88-
i in eachindex(axes(params_and_stats, 1)), key in stat_keys,
89-
j in eachindex(axes(params_and_stats, 2))
90-
]
91-
# Construct name map and info
92-
name_map = (internals=stat_keys,)
93-
info = (
94-
varname_to_symbol=DynamicPPL.OrderedCollections.OrderedDict(
95-
zip(all_vn_leaves, param_symbols)
96-
),
97-
)
98-
# Concatenate parameter and statistic values
99-
vals = cat(param_vals, stat_vals; dims=2)
100-
symbols = vcat(param_symbols, stat_keys)
101-
return MCMCChains.Chains(MCMCChains.concretize(vals), symbols, name_map; info=info)
102-
end
103-
function DynamicPPL.to_chains(
104-
::Type{MCMCChains.Chains}, ps::AbstractVector{<:DynamicPPL.ParamsWithStats}
105-
)
106-
return DynamicPPL.to_chains(MCMCChains.Chains, hcat(ps))
107-
end
108-
109-
function DynamicPPL.from_chains(
110-
::Type{T}, chain::MCMCChains.Chains
111-
) where {T<:AbstractDict{<:DynamicPPL.VarName}}
112-
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
113-
matrix = map(idxs) do (sample_idx, chain_idx)
114-
d = T()
115-
for vn in DynamicPPL.varnames(chain)
116-
d[vn] = DynamicPPL.getindex_varname(chain, sample_idx, vn, chain_idx)
117-
end
118-
d
119-
end
120-
return matrix
121-
end
122-
function DynamicPPL.from_chains(::Type{NamedTuple}, chain::MCMCChains.Chains)
123-
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
124-
matrix = map(idxs) do (sample_idx, chain_idx)
125-
get(chain[sample_idx, :, chain_idx], keys(chain); flatten=true)
126-
end
127-
return matrix
128-
end
129-
function DynamicPPL.from_chains(
130-
::Type{DynamicPPL.ParamsWithStats}, chain::MCMCChains.Chains
131-
)
132-
idxs = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
133-
internals_chain = MCMCChains.get_sections(chain, :internals)
134-
params = DynamicPPL.from_chains(
135-
DynamicPPL.OrderedCollections.OrderedDict{DynamicPPL.VarName,eltype(chain.value)},
136-
chain,
137-
)
138-
stats = DynamicPPL.from_chains(NamedTuple, internals_chain)
139-
return map(idxs) do (sample_idx, chain_idx)
140-
DynamicPPL.ParamsWithStats(
141-
params[sample_idx, chain_idx], stats[sample_idx, chain_idx]
142-
)
143-
end
144-
end
145-
14639
"""
14740
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
14841
@@ -217,6 +110,7 @@ function DynamicPPL.predict(
217110
DynamicPPL.VarInfo(),
218111
(
219112
DynamicPPL.LogPriorAccumulator(),
113+
DynamicPPL.LogJacobianAccumulator(),
220114
DynamicPPL.LogLikelihoodAccumulator(),
221115
DynamicPPL.ValuesAsInModelAccumulator(false),
222116
),
@@ -235,9 +129,23 @@ function DynamicPPL.predict(
235129
varinfo,
236130
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
237131
)
238-
DynamicPPL.ParamsWithStats(varinfo, nothing)
132+
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
133+
varname_vals = mapreduce(
134+
collect,
135+
vcat,
136+
map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
137+
)
138+
139+
return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
239140
end
240-
chain_result = DynamicPPL.to_chains(MCMCChains.Chains, predictive_samples)
141+
142+
chain_result = reduce(
143+
MCMCChains.chainscat,
144+
[
145+
_predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
146+
chain_idx in 1:size(predictive_samples, 2)
147+
],
148+
)
241149
parameter_names = if include_all
242150
MCMCChains.names(chain_result, :parameters)
243151
else
@@ -256,6 +164,45 @@ function DynamicPPL.predict(
256164
)
257165
end
258166

167+
function _predictive_samples_to_arrays(predictive_samples)
168+
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
169+
170+
sample_dicts = map(predictive_samples) do sample
171+
varname_value_pairs = sample.varname_and_values
172+
varnames = map(first, varname_value_pairs)
173+
values = map(last, varname_value_pairs)
174+
for varname in varnames
175+
push!(variable_names_set, varname)
176+
end
177+
178+
return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
179+
end
180+
181+
variable_names = collect(variable_names_set)
182+
variable_values = [
183+
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
184+
key in variable_names
185+
]
186+
187+
return variable_names, variable_values
188+
end
189+
190+
function _predictive_samples_to_chains(predictive_samples)
191+
variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
192+
variable_names_symbols = map(Symbol, variable_names)
193+
194+
internal_parameters = [:lp]
195+
log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)
196+
197+
parameter_names = [variable_names_symbols; internal_parameters]
198+
parameter_values = hcat(variable_values, log_probabilities)
199+
parameter_values = MCMCChains.concretize(parameter_values)
200+
201+
return MCMCChains.Chains(
202+
parameter_values, parameter_names, (internals=internal_parameters,)
203+
)
204+
end
205+
259206
"""
260207
returned(model::Model, chain::MCMCChains.Chains)
261208

src/DynamicPPL.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ export AbstractVarInfo,
126126
prefix,
127127
returned,
128128
to_submodel,
129-
# Chain construction
130-
ParamsWithStats,
131-
to_chains,
132129
# Convenience macros
133130
@addlogprob!,
134131
value_iterator_from_chain,
@@ -197,7 +194,6 @@ include("model_utils.jl")
197194
include("extract_priors.jl")
198195
include("values_as_in_model.jl")
199196
include("bijector.jl")
200-
include("to_chains.jl")
201197

202198
include("debug_utils.jl")
203199
using .DebugUtils

0 commit comments

Comments
 (0)