@@ -11,6 +11,12 @@ Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for
1111resampling. 
1212""" 
1313function  set_all_del! (vi:: AbstractVarInfo )
14+     #  TODO (penelopeysm): Instead of being a 'del' flag on the VarInfo, we
15+     #  could either:
16+     #  - keep a boolean 'resample' flag on the trace, or
17+     #  - modify the model context appropriately.
18+     #  However, this refactoring will have to wait until InitContext is
19+     #  merged into DPPL.
1420    for  vn in  keys (vi)
1521        DynamicPPL. set_flag! (vi, vn, " del"  )
1622    end 
5965function  AdvancedPS. advance! (
6066    trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} , isref:: Bool = false 
6167)
62-     #  We want to increment num produce for the VarInfo stored in the trace. The trace is
63-     #  mutable, so we create a new model with the incremented VarInfo and set it in the trace
64-     model =  trace. model
65-     model =  Accessors. @set  model. f. varinfo =  DynamicPPL. increment_num_produce!! (
66-         model. f. varinfo
67-     )
68-     trace. model =  model
6968    #  Make sure we load/reset the rng in the new replaying mechanism
7069    isref ?  AdvancedPS. load_state! (trace. rng) :  AdvancedPS. save_state! (trace. rng)
7170    score =  consume (trace. model. ctask)
7271    return  score
7372end 
7473
7574function  AdvancedPS. delete_retained! (trace:: TracedModel )
76-     #  TODO (DPPL0.37/penelopeysm): Explain this a bit better.
77-     # 
7875    #  This method is called if, during a CSMC update, we perform a resampling
7976    #  and choose the reference particle as the trajectory to carry on from.
8077    #  In such a case, we need to ensure that when we continue sampling (i.e.
8178    #  the next time we hit tilde_assume), we don't use the values in the 
8279    #  reference particle but rather sample new values.
83-     #  In this implementation, we indiscriminately set the 'del' flag for all
84-     #  variables in the VarInfo. This is slightly overkill: it is not necessary
85-     #  to set the 'del' flag for variables that were already sampled. However,
86-     #  it allows us to avoid using DynamicPPL.set_retained_vns_del!.
80+     # 
81+     #  Here, we indiscriminately set the 'del' flag for all variables in the
82+     #  VarInfo. This is slightly overkill: it is not necessary to set the 'del'
83+     #  flag for variables that were already sampled. However, it allows us to
84+     #  avoid keeping track of which variables were sampled, which leads to many
85+     #  simplifications in the VarInfo data structure.
8786    set_all_del! (trace. varinfo)
8887    return  trace
8988end 
9089
9190function  AdvancedPS. reset_model (trace:: TracedModel )
92-     return  Accessors . @set   trace. varinfo  =  DynamicPPL . reset_num_produce!! (trace . varinfo) 
91+     return  trace
9392end 
9493
9594function  Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
@@ -213,7 +212,6 @@ function DynamicPPL.initialstep(
213212)
214213    #  Reset the VarInfo.
215214    vi =  DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
216-     vi =  DynamicPPL. reset_num_produce!! (vi)
217215    set_all_del! (vi)
218216    vi =  DynamicPPL. resetlogp!! (vi)
219217    vi =  DynamicPPL. empty!! (vi)
@@ -344,7 +342,6 @@ function DynamicPPL.initialstep(
344342)
345343    vi =  DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
346344    #  Reset the VarInfo before new sweep
347-     vi =  DynamicPPL. reset_num_produce!! (vi)
348345    set_all_del! (vi)
349346    vi =  DynamicPPL. resetlogp!! (vi)
350347
@@ -366,11 +363,6 @@ function DynamicPPL.initialstep(
366363
367364    #  Compute the first transition.
368365    _vi =  reference. model. f. varinfo
369-     #  Unset any 'del' flags before we actually construct the transition.
370-     #  This is necessary because the model will be re-evaluated and we
371-     #  want to make sure we do use the values in the reference particle
372-     #  instead of resampling them.
373-     unset_all_del! (_vi)
374366    transition =  PGTransition (model, _vi, logevidence)
375367
376368    return  transition, PGState (_vi, reference. rng)
@@ -382,10 +374,10 @@ function AbstractMCMC.step(
382374    #  Reset the VarInfo before new sweep.
383375    vi =  state. vi
384376    vi =  DynamicPPL. setacc!! (vi, ProduceLogLikelihoodAccumulator ())
385-     vi =  DynamicPPL. reset_num_produce!! (vi)
386377    vi =  DynamicPPL. resetlogp!! (vi)
387378
388379    #  Create reference particle for which the samples will be retained.
380+     unset_all_del! (vi)
389381    reference =  AdvancedPS. forkr (AdvancedPS. Trace (model, spl, vi, state. rng))
390382
391383    #  For all other particles, do not retain the variables but resample them.
@@ -412,11 +404,6 @@ function AbstractMCMC.step(
412404
413405    #  Compute the transition.
414406    _vi =  newreference. model. f. varinfo
415-     #  Unset any 'del' flags before we actually construct the transition.
416-     #  This is necessary because the model will be re-evaluated and we
417-     #  want to make sure we do use the values in the reference particle
418-     #  instead of resampling them.
419-     unset_all_del! (_vi)
420407    transition =  PGTransition (model, _vi, logevidence)
421408
422409    return  transition, PGState (_vi, newreference. rng)
@@ -499,12 +486,11 @@ function DynamicPPL.assume(
499486        vi =  push!! (vi, vn, r, dist)
500487    elseif  DynamicPPL. is_flagged (vi, vn, " del"  )
501488        DynamicPPL. unset_flag! (vi, vn, " del"  ) #  Reference particle parent
502-         r =  rand (trng, dist)
503-         vi[vn] =  DynamicPPL. tovec (r)
504489        #  TODO (mhauru):
505490        #  The below is the only line that differs from assume called on SampleFromPrior.
506-         #  Could we just call assume on SampleFromPrior and then `setorder!!` after that?
507-         vi =  DynamicPPL. setorder!! (vi, vn, DynamicPPL. get_num_produce (vi))
491+         #  Could we just call assume on SampleFromPrior with a specific rng?
492+         r =  rand (trng, dist)
493+         vi[vn] =  DynamicPPL. tovec (r)
508494    else 
509495        r =  vi[vn]
510496    end 
@@ -546,8 +532,6 @@ function AdvancedPS.Trace(
546532    rng:: AdvancedPS.TracedRNG ,
547533)
548534    newvarinfo =  deepcopy (varinfo)
549-     newvarinfo =  DynamicPPL. reset_num_produce!! (newvarinfo)
550- 
551535    tmodel =  TracedModel (model, sampler, newvarinfo, rng)
552536    newtrace =  AdvancedPS. Trace (tmodel, rng)
553537    return  newtrace
0 commit comments