Should we always pass rng to the model? #721

@torfjelde

Back in the day, the evaluator for a @model would look like

function demo(rng, model, varinfo, context, sampler)
    ...
end

or something like this.
But when we started making heavier use of contexts in #249, this was simplified to

function demo(model, varinfo, context)
    ...
end

and the rng and sampler arguments are instead provided, when needed, through SamplingContext(rng, sampler, context). This gives a clear separation between when we're sampling and when we're evaluating a model.

However, this means that someone can no longer define a model with inherent randomness in it while still preserving determinism conditional on an rng.

A simple example is an implementation of a model with subsampling of the data. I could do this as follows:

julia> using DynamicPPL, Distributions

julia> @model function demo(y, batch_size=1)
           x ~ Normal()
           y_indices = rand(1:length(y), batch_size)
           y[y_indices] .~ Normal(x, 1)
       end
demo (generic function with 4 methods)

julia> model = demo(randn(32), 4);

julia> model()
4-element view(::Vector{Float64}, [11, 18, 20, 5]) with eltype Float64:
 -0.35070761628093716
  0.19218100395285695
  0.8506289607980133
  1.0317998072662038

julia> rand(model)
(x = -0.991550644289528,)

However, the issue in the above is ofc that the rand call inside @model doesn't have access to the rng used internally in the model, and thus cannot ensure that everything is deterministic given an rng.
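
To make the problem concrete without DynamicPPL, here is a minimal sketch in plain Julia. The two functions are hypothetical stand-ins for a model evaluator, not anything from the package: one draws the subsampling indices from the global rng (as the `rand` call in the model above does), the other threads the supplied rng through every random call.

```julia
using Random

# Hypothetical stand-in for the model above: `x` is drawn from the supplied
# rng, but the subsampling indices come from the global rng.
function subsample_global(rng::AbstractRNG, y, batch_size)
    x = randn(rng)
    y_indices = rand(1:length(y), batch_size)       # ignores `rng`
    return x, y_indices
end

# Same computation, but every random draw uses the supplied rng.
function subsample_threaded(rng::AbstractRNG, y, batch_size)
    x = randn(rng)
    y_indices = rand(rng, 1:length(y), batch_size)  # uses `rng`
    return x, y_indices
end

y = collect(1.0:32.0)

# With identical seeds, the threaded version reproduces both `x` and the indices.
a = subsample_threaded(MersenneTwister(42), y, 4)
b = subsample_threaded(MersenneTwister(42), y, 4)
println(a == b)

# The global version only reproduces `x`; the indices depend on whatever
# state the global rng happens to be in.
c = subsample_global(MersenneTwister(42), y, 4)
d = subsample_global(MersenneTwister(42), y, 4)
println(first(c) == first(d))
```

Only the threaded version is deterministic given the rng, which is the property the model above loses.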

Ofc, the user could provide the rng as a specific argument, but this seems quite redundant as we often will have an rng available.

As a result, I'm thinking that it might be useful to remove the rng from the SamplingContext and instead make it a similarly "private" variable so users could do

julia> @model function demo(y, batch_size=1)
           x ~ Normal()
           y_indices = rand(__rng__, 1:length(y), batch_size)
           y[y_indices] .~ Normal(x, 1)
       end
demo (generic function with 4 methods)

as they could before.

Thoughts? @mhauru @penelopeysm @yebai @willtebbutt @sunxd3
