Closed as not planned
Description
#984 uses `init!!` to implement `predict`. However, the `include_all=false` path is wasteful: it first constructs a chain containing all parameters (including the ones we don't want) and only then subsets that chain. It seems more sensible to filter the dictionary of `varname => value` pairs inside the loop, on each iteration, so that those variables never end up in the chain in the first place.
DynamicPPL.jl/ext/DynamicPPLMCMCChainsExt.jl
Lines 122 to 158 in 956ed54
```julia
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        # Extract values from the chain
        values_dict = chain_sample_to_varname_dict(parameter_only_chain, sample_idx, chain_idx)
        # Resample any variables that are not present in `values_dict`
        _, varinfo = DynamicPPL.init!!(
            rng,
            model,
            varinfo,
            DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
        )
        vals = DynamicPPL.values_as_in_model(model, false, varinfo)
        varname_vals = mapreduce(
            collect,
            vcat,
            map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )
        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
    end
    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end
```
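A minimal, self-contained sketch of the proposed in-loop filtering follows. It assumes plain `String` keys stand in for `VarName`s and that a hand-built leaf vector stands in for the output of `init!!` followed by `values_as_in_model`; `predict_leaves`, `posterior`, and `model_leaves` are hypothetical names that do not exist in DynamicPPL. The point is only that the unwanted pairs are dropped per iteration, before anything is turned into a chain.

```julia
# Hypothetical sketch: `String` keys stand in for DynamicPPL `VarName`s,
# and `Dict`s stand in for the per-iteration `values_dict`.

"""
    predict_leaves(sample_params, model_leaves; include_all)

Return the `(varname, value)` pairs to store in the predictive chain for a
single iteration. With `include_all=false`, pairs whose varname was supplied
by the input chain are dropped here, before any chain is constructed.
"""
function predict_leaves(
    sample_params::AbstractDict, model_leaves::AbstractVector; include_all::Bool
)
    include_all && return model_leaves
    return filter(p -> !haskey(sample_params, first(p)), model_leaves)
end

# Two posterior draws of the parameters `m` and `s`.
posterior = [Dict("m" => 0.25, "s" => 1.5), Dict("m" => -0.5, "s" => 2.0)]

predictive = map(posterior) do sample_params
    # Stand-in for `init!!` + `values_as_in_model`: the model reports every
    # leaf, including the newly sampled observation `y`.
    model_leaves = [(k, v) for (k, v) in sample_params]
    push!(model_leaves, ("y", sample_params["m"] + 1.0))
    predict_leaves(sample_params, model_leaves; include_all=false)
end
# Each entry of `predictive` now holds only the `("y", value)` pair.
```

With `include_all=true` the leaves pass through unchanged, so the filter costs one lookup per leaf and the final `chain_result[parameter_names]` subsetting step becomes unnecessary. In DynamicPPL itself, exact key lookup is probably not enough: a variable `x` in `values_dict` may correspond to leaves such as `x[1]` and `x[2]`, so a real implementation would likely need a subsumption check (for example AbstractPPL's `subsumes`) instead of `haskey`.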
Not making this change in #984 to avoid complicating matters.