@@ -2,7 +2,7 @@ mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
22 model:: Tmodel
33 spl:: Tspl
44 vi:: Tvi
5- task :: Task
5+ ctask :: CTask
66
77 function Trace {SampleFromPrior} (model:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
88 return new {SampleFromPrior,typeof(vi),typeof(model)} (model, SampleFromPrior (), vi)
1515function Base. copy (trace:: Trace )
1616 vi = deepcopy (trace. vi)
1717 res = Trace {typeof(trace.spl)} (trace. model, trace. spl, vi)
18- res. task = copy (trace. task )
18+ res. ctask = copy (trace. ctask )
1919 return res
2020end
2121
2222# NOTE: this function is called by `forkr`
2323function Trace (f:: Function , m:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
2424 res = Trace {typeof(spl)} (m, spl, deepcopy (vi));
25- # CTask(()-> f());
26- res . task = CTask ( () -> begin res = f (); produce (Val{ :done }); res; end )
27- if res . task. storage === nothing
28- res . task. storage = IdDict ()
25+ ctask = CTask (() -> (res = f (); produce (Val{ :done }); res))
26+ task = ctask . task
27+ if task. storage === nothing
28+ task. storage = IdDict ()
2929 end
30- res. task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
30+ task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
31+ res. ctask = ctask
3132 return res
3233end
3334function Trace (m:: Model , spl:: AbstractSampler , vi:: AbstractVarInfo )
3435 res = Trace {typeof(spl)} (m, spl, deepcopy (vi));
35- # CTask(()->f());
3636 reset_num_produce! (res. vi)
37- res. task = CTask ( () -> begin vi_new= m (vi, spl); produce (Val{:done }); vi_new; end )
38- if res. task. storage === nothing
39- res. task. storage = IdDict ()
37+ ctask = CTask (() -> (vi_new = m (vi, spl); produce (Val{:done }); vi_new))
38+ task = ctask. task
39+ if task. storage === nothing
40+ task. storage = IdDict ()
4041 end
41- res. task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
42+ task. storage[:turing_trace ] = res # create a backward reference in task_local_storage
43+ res. ctask = ctask
4244 return res
4345end
4446
4547# step to the next observe statement, return log likelihood
46- Libtask. consume (t:: Trace ) = (increment_num_produce! (t. vi); consume (t. task ))
48+ Libtask. consume (t:: Trace ) = (increment_num_produce! (t. vi); consume (t. ctask ))
4749
4850# Task copying version of fork for Trace.
4951function fork (trace :: Trace , is_ref :: Bool = false )
5052 newtrace = copy (trace)
5153 is_ref && set_retained_vns_del_by_spl! (newtrace. vi, newtrace. spl)
52- newtrace. task. storage[:turing_trace ] = newtrace
54+ newtrace. ctask . task. storage[:turing_trace ] = newtrace
5355 return newtrace
5456end
5557
5658# PG requires keeping all randomness for the reference particle
5759# Create new task and copy randomness
58- function forkr (trace :: Trace )
59- newtrace = Trace (trace. task. code, trace. model, trace. spl, deepcopy (trace. vi))
60+ function forkr (trace:: Trace )
61+ newtrace = Trace (trace. ctask . task. code, trace. model, trace. spl, deepcopy (trace. vi))
6062 newtrace. spl = trace. spl
6163 reset_num_produce! (newtrace. vi)
6264 return newtrace
0 commit comments