Description
This issue came up again on the Julia Slack: how to store the values of only specific random variables, rather than every random variable (https://julialang.slack.com/archives/CCYDC34A0/p1742032177765899).
See also:
- Specify which variables to track (Turing.jl#1444): essentially the same request.
- `:=` to keep track of generated quantities (#594): a way to do the opposite, i.e. store extra quantities that aren't random variables.
The solution I initially suggested was to take the raw `Turing.Inference.Transition`s returned by the call to `AbstractMCMC.sample` and subset them appropriately before constructing the chain. Tor's comment on TuringLang/Turing.jl#1444 does essentially the same thing.
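A minimal sketch of that post-hoc workaround, assuming each raw transition can be treated as a dictionary from variable name to value. `MockTransition` and `subset_transitions` are hypothetical names for illustration only; the real `Turing.Inference.Transition` stores its values differently, so this shows just the shape of the filtering step.

```julia
# Hypothetical stand-in for one sampled transition: variable name => value.
const MockTransition = Dict{Symbol,Float64}

# Keep only the entries whose names appear in `keep`, for every transition.
function subset_transitions(transitions::Vector{MockTransition}, keep::Vector{Symbol})
    return [filter(kv -> first(kv) in keep, t) for t in transitions]
end

transitions = [
    MockTransition(:x => 0.1, :y => 1.2, :z => -0.3),
    MockTransition(:x => 0.4, :y => 0.9, :z => 0.7),
]

# Drop `z` before the chain would be constructed.
subset = subset_transitions(transitions, [:x, :y])
```

Note that the full transitions are still built during sampling; the filtering only happens afterwards.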
Having thought about it a bit more, though, it seems to me that now that we have `values_as_in_model`, we could use it as a 'hook' for the user to declare which variables they're interested in.
Specifically:
`DynamicPPL.jl/src/values_as_in_model.jl`, lines 22 to 29 in d1a98c6:

```julia
struct ValuesAsInModelContext{C<:AbstractContext} <: AbstractContext
    "values that are extracted from the model"
    values::OrderedDict
    "whether to extract variables on the LHS of :="
    include_colon_eq::Bool
    "child context"
    context::C
end
```
^ Add a field to `ValuesAsInModelContext` to denote the varnames that we want to track. It could default to `nothing` to indicate that all varnames should be tracked.
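A self-contained sketch of what the extra field could look like. It mocks `AbstractContext`, `DefaultContext` and the context struct instead of using DynamicPPL's real types, swaps `OrderedDict` for a plain `Dict` (so insertion order is not preserved) to avoid the OrderedCollections dependency, and the `tracked` constructor argument is an assumed name.

```julia
# Simplified mocks; not DynamicPPL's actual definitions.
abstract type AbstractContext end
struct DefaultContext <: AbstractContext end

struct ValuesAsInModelContext{C<:AbstractContext,T} <: AbstractContext
    # values that are extracted from the model
    values::Dict{Any,Any}
    # whether to extract variables on the LHS of :=
    include_colon_eq::Bool
    # hypothetical new field: varnames to track; `nothing` means track all of them
    tracked_varnames::T
    # child context
    context::C
end

# Convenience constructor: `tracked` defaults to `nothing`,
# so existing two-argument callers keep tracking everything.
function ValuesAsInModelContext(include_colon_eq::Bool, context::AbstractContext, tracked=nothing)
    return ValuesAsInModelContext(Dict{Any,Any}(), include_colon_eq, tracked, context)
end

ctx_all = ValuesAsInModelContext(true, DefaultContext())
ctx_some = ValuesAsInModelContext(false, DefaultContext(), [:x])
```

Keeping the field's type as a type parameter lets `nothing` and a concrete vector of names be distinguished by dispatch rather than by a runtime check.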
`DynamicPPL.jl/src/values_as_in_model.jl`, lines 72 to 73 in d1a98c6:

```julia
# Save the value.
push!(context, vn, value)
```
^ Do a check here to see whether `vn` is subsumed by any of the varnames we want to track, and skip the `push!` if it isn't.
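The check could look roughly like this. AbstractPPL provides `subsumes(u, v)` for `VarName`s; in this sketch both the varnames and the subsumption test are replaced by hypothetical simplified stand-ins (`MockVarName`, `mock_subsumes`, which only treat index prefixes as subsumed, e.g. `x` subsumes `x[1]`). `should_track` and `maybe_save!` are illustrative names, not existing functions.

```julia
# Hypothetical simplified varname: a symbol plus a tuple of indices,
# e.g. `MockVarName(:x, (1,))` stands for `x[1]`.
struct MockVarName
    sym::Symbol
    indices::Tuple
end

# Simplified stand-in for `subsumes(u, v)`: same symbol, and
# `u`'s indices are a prefix of `v`'s (no ranges, no field access).
function mock_subsumes(u::MockVarName, v::MockVarName)
    u.sym == v.sym || return false
    n = length(u.indices)
    return n <= length(v.indices) && v.indices[1:n] == u.indices
end

# `nothing` means "track everything"; otherwise track `vn`
# if any of the tracked names subsumes it.
should_track(::Nothing, ::MockVarName) = true
should_track(tracked::AbstractVector{MockVarName}, vn::MockVarName) =
    any(t -> mock_subsumes(t, vn), tracked)

# The proposed check, placed right before the value is saved.
function maybe_save!(values::Dict{MockVarName,Any}, tracked, vn::MockVarName, value)
    should_track(tracked, vn) || return values
    values[vn] = value
    return values
end

x = MockVarName(:x, ())
x1 = MockVarName(:x, (1,))
y = MockVarName(:y, ())

vals = Dict{MockVarName,Any}()
maybe_save!(vals, [x], x1, 0.5)  # saved: `x` subsumes `x[1]`
maybe_save!(vals, [x], y, 2.0)   # skipped: `y` is not tracked
```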
`DynamicPPL.jl/src/values_as_in_model.jl`, lines 166 to 175 in d1a98c6:

```julia
function values_as_in_model(
    model::Model,
    include_colon_eq::Bool,
    varinfo::AbstractVarInfo,
    context::AbstractContext=DefaultContext(),
)
    context = ValuesAsInModelContext(include_colon_eq, context)
    evaluate!!(model, varinfo, context)
    return context.values
end
```
^ Add a new argument to this function, `tracked_varnames`, with a default value of `tracked_varnames(model)`.
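A mock of the new argument and its default, as a sketch only. `MockModel`, the callback-based "evaluation" and `keep` are hypothetical stand-ins for DynamicPPL's model evaluation; the `include_colon_eq`, varinfo and context arguments are dropped, and names are plain `Symbol`s matched exactly rather than `VarName`s matched by subsumption. The argument is called `tracked` here only to keep it visually distinct from the `tracked_varnames` function.

```julia
# Hypothetical mock: a "model" wraps a function that reports each
# (name, value) pair through a callback, standing in for model evaluation.
struct MockModel{F}
    f::F
end

# Default: `nothing` means every variable is tracked.
tracked_varnames(::MockModel) = nothing

keep(::Nothing, ::Symbol) = true
keep(tracked, name::Symbol) = name in tracked

# The new trailing argument defaults to `tracked_varnames(model)`.
function values_as_in_model(model::MockModel, tracked=tracked_varnames(model))
    values = Dict{Symbol,Any}()
    model.f() do name, value
        if keep(tracked, name)
            values[name] = value
        end
    end
    return values
end

m = MockModel(save -> (save(:x, 1.0); save(:y, 2.0); save(:z, 3.0)))
all_vals = values_as_in_model(m)        # uses the default: track everything
only_x = values_as_in_model(m, [:x])    # explicit subset
```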
Finally, to preserve the default behaviour, we implement:

```julia
tracked_varnames(::Model) = nothing
```

And then a user can just override `tracked_varnames` for their own model:

```julia
@model mymodel() = ...
DynamicPPL.tracked_varnames(::Model{typeof(mymodel)}) = [@varname(x), @varname(y), ...]
```
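The override relies on ordinary multiple dispatch on the model's type parameter. The sketch below uses a hypothetical `MockModel{F}` that carries the generating function's type as its first parameter, which is what the `Model{typeof(mymodel)}` pattern above relies on; plain `Symbol`s stand in for `@varname`s.

```julia
# Hypothetical mock of a model type parameterised by its generating function.
struct MockModel{F}
    f::F
end

# Default for all models: track everything.
tracked_varnames(::MockModel) = nothing

# Two user-defined model functions (bodies irrelevant here).
mymodel() = nothing
othermodel() = nothing

# User override: applies only to models built from `mymodel`.
tracked_varnames(::MockModel{typeof(mymodel)}) = [:x, :y]
```

In this mock, `tracked_varnames(MockModel(mymodel))` picks the more specific method and returns `[:x, :y]`, while any other model falls back to `nothing`.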