Back in the day, the evaluator for a `@model` would look like

```julia
function demo(rng, model, varinfo, context, sampler)
    ...
end
```

or something like this.
But when we started making more extensive use of contexts in #249, this instead became the simpler

```julia
function demo(model, varinfo, context)
    ...
end
```

and the `rng` and `sampler` are now provided only when needed, through `SamplingContext(rng, sampler, context)`. This gives a clear separation between sampling from a model and evaluating it.
However, as a consequence, users can't define models with inherent randomness in them while still preserving determinism conditional on an `rng`.
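A toy sketch of that situation in plain Julia. `ToyContext`, `ToyDefaultContext`, `ToySamplingContext` and `toy_rng` are hypothetical names made up for illustration and are not DynamicPPL's actual definitions. The point is only that the RNG can be recovered when evaluation runs under a sampling context, while plain evaluation leaves the model body with nothing to draw from.

```julia
using Random

# Hypothetical toy contexts, loosely modelled on the idea of a sampling context
# that wraps a child context. These are NOT DynamicPPL's real type definitions.
abstract type ToyContext end

struct ToyDefaultContext <: ToyContext end

struct ToySamplingContext{R<:AbstractRNG,S,C<:ToyContext} <: ToyContext
    rng::R        # RNG used while sampling
    sampler::S    # sampler (any object in this toy)
    context::C    # wrapped child context
end

# The RNG is only reachable when evaluating under a sampling context.
toy_rng(ctx::ToySamplingContext) = ctx.rng
toy_rng(::ToyDefaultContext) = nothing

sampling_ctx = ToySamplingContext(MersenneTwister(0), :toy_sampler, ToyDefaultContext())
rng_when_sampling = toy_rng(sampling_ctx)             # the MersenneTwister wrapped above
rng_when_evaluating = toy_rng(ToyDefaultContext())    # nothing: no RNG available
```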
A simple example is an implementation of a model with subsampling of the data. I could do this as follows:
```julia
julia> using DynamicPPL, Distributions

julia> @model function demo(y, batch_size=1)
           x ~ Normal()
           y_indices = rand(1:length(y), batch_size)
           y[y_indices] .~ Normal(x, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(randn(32), 4);

julia> model()
4-element view(::Vector{Float64}, [11, 18, 20, 5]) with eltype Float64:
 -0.35070761628093716
  0.19218100395285695
  0.8506289607980133
  1.0317998072662038

julia> rand(model)
(x = -0.991550644289528,)
```

However, the issue with the above is, of course, that the `rand` call inside `@model` doesn't have access to the `rng` used internally by the model, so we cannot ensure that everything is deterministic given an `rng`.
Of course, the user could pass the `rng` as an explicit model argument, but this seems quite redundant, since we will often already have an `rng` available.
As a result, I'm thinking that it might be useful to remove the `rng` from `SamplingContext` and instead make it a similarly "private" variable, so that users could once again write

```julia
julia> @model function demo(y, batch_size=1)
           x ~ Normal()
           y_indices = rand(__rng__, 1:length(y), batch_size)
           y[y_indices] .~ Normal(x, 1)
       end
demo (generic function with 4 methods)
```
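A plain-Julia sketch of what this would buy us, with no DynamicPPL involved. The hypothetical `toy_demo_body` stands in for the body of `demo` and receives `__rng__` from a hypothetical evaluator `toy_evaluate`, which falls back to `Random.default_rng()` when the caller gives no RNG. Neither function is part of DynamicPPL; they only illustrate that two evaluations with identically seeded RNGs pick the same subsample.

```julia
using Random

# Hypothetical stand-in for the body of `demo`: `__rng__` is handed in by the
# evaluator, in the spirit of the proposed "private" variable.
function toy_demo_body(__rng__::AbstractRNG, y::AbstractVector, batch_size::Integer)
    # Subsample the data with the provided RNG instead of the global one.
    return rand(__rng__, 1:length(y), batch_size)
end

# Hypothetical evaluator: the caller may supply an RNG (e.g. when sampling);
# otherwise fall back to the default RNG.
toy_evaluate(rng::AbstractRNG, y::AbstractVector, batch_size::Integer) =
    toy_demo_body(rng, y, batch_size)
toy_evaluate(y::AbstractVector, batch_size::Integer) =
    toy_evaluate(Random.default_rng(), y, batch_size)

y = randn(32)
idx_a = toy_evaluate(MersenneTwister(42), y, 4)
idx_b = toy_evaluate(MersenneTwister(42), y, 4)
idx_a == idx_b  # true: same seed, same subsample
```

The sketch returns the indices rather than `y[y_indices]` to keep the comparison simple; in the actual model the selected observations would then be scored with `.~`.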
Thoughts? @mhauru @penelopeysm @yebai @willtebbutt @sunxd3