Skip to content

Commit 0fa5540

Browse files
authored
fix include_all kwarg for predict, improve perf (#1068)
* Fix `include_all` for predict * Fix include_all for predict, some perf improvements
1 parent ec65b4f commit 0fa5540

File tree

3 files changed

+42
-2
lines changed

3 files changed

+42
-2
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ The `resume_from=chn` keyword argument to `sample` has been removed; please use
6161

6262
**Other changes**
6363

64+
### `predict(model, chain; include_all)`
65+
66+
The `include_all` keyword argument for `predict` now works even when no RNG is specified (previously it would only work when an RNG was explicitly passed).
67+
6468
### `setleafcontext(model, context)`
6569

6670
This convenience method has been added to quickly modify the leaf context of a model.

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,19 @@ function DynamicPPL.predict(
116116
include_all=false,
117117
)
118118
parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
119-
varinfo = DynamicPPL.VarInfo(model)
119+
120+
# Set up a VarInfo with the right accumulators
121+
varinfo = DynamicPPL.setaccs!!(
122+
DynamicPPL.VarInfo(),
123+
(
124+
DynamicPPL.LogPriorAccumulator(),
125+
DynamicPPL.LogJacobianAccumulator(),
126+
DynamicPPL.LogLikelihoodAccumulator(),
127+
DynamicPPL.ValuesAsInModelAccumulator(false),
128+
),
129+
)
130+
_, varinfo = DynamicPPL.init!!(model, varinfo)
131+
varinfo = DynamicPPL.typed_varinfo(varinfo)
120132

121133
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
122134
predictive_samples = map(iters) do (sample_idx, chain_idx)
@@ -129,7 +141,7 @@ function DynamicPPL.predict(
129141
varinfo,
130142
DynamicPPL.InitFromParams(values_dict, DynamicPPL.InitFromPrior()),
131143
)
132-
vals = DynamicPPL.values_as_in_model(model, false, varinfo)
144+
vals = DynamicPPL.getacc(varinfo, Val(:ValuesAsInModel)).values
133145
varname_vals = mapreduce(
134146
collect,
135147
vcat,
@@ -156,6 +168,13 @@ function DynamicPPL.predict(
156168
end
157169
return chain_result[parameter_names]
158170
end
171+
function DynamicPPL.predict(
172+
model::DynamicPPL.Model, chain::MCMCChains.Chains; include_all=false
173+
)
174+
return DynamicPPL.predict(
175+
DynamicPPL.Random.default_rng(), model, chain; include_all=include_all
176+
)
177+
end
159178

160179
function _predictive_samples_to_arrays(predictive_samples)
161180
variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

test/model.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,23 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
519519
@test Set(keys(predictions)) == Set([Symbol("y[1]"), Symbol("y[2]")])
520520
end
521521

522+
@testset "include_all=true" begin
523+
inc_predictions = DynamicPPL.predict(
524+
m_lin_reg_test, β_chain; include_all=true
525+
)
526+
@test Set(keys(inc_predictions)) ==
527+
Set([, Symbol("y[1]"), Symbol("y[2]")])
528+
@test inc_predictions[] == β_chain[]
529+
# check rng is respected
530+
inc_predictions1 = DynamicPPL.predict(
531+
Xoshiro(468), m_lin_reg_test, β_chain; include_all=true
532+
)
533+
inc_predictions2 = DynamicPPL.predict(
534+
Xoshiro(468), m_lin_reg_test, β_chain; include_all=true
535+
)
536+
@test all(Array(inc_predictions1) .== Array(inc_predictions2))
537+
end
538+
522539
@testset "accuracy" begin
523540
ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))
524541
@test ys_pred[1] ground_truth_β * xs_test[1] atol = 0.01

0 commit comments

Comments
 (0)