@@ -11,59 +11,42 @@ function _error_msg()
1111 return " This macro is only for use in the `@model` macro and not for external use."
1212end
1313
14-
15-
16- # Check if the right-hand side is a distribution.
17- function assert_dist (dist; msg)
18- isa (dist, Distribution) || throw (ArgumentError (msg))
19- end
20- function assert_dist (dist:: AbstractVector ; msg)
21- all (d -> isa (d, Distribution), dist) || throw (ArgumentError (msg))
22- end
23-
24- function wrong_dist_errormsg (l)
25- return " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
26- " Distributions on line $(l) ."
27- end
14+ const DISTMSG = " Right-hand side of a ~ must be subtype of Distribution or a vector of " *
15+ " Distributions."
2816
2917"""
30- @isassumption(model, expr)
18+ isassumption(model, expr)
19+
20+ Return an expression that can be evaluated to check if `expr` is an assumption in the
21+ `model`.
3122
32- Let `expr` be `x[1]`. `vn` is an assumption in the following cases:
33- 1. `x` was not among the input data to the model,
34- 2. `x` was among the input data to the model but with a value `missing`, or
35- 3. `x` was among the input data to the model with a value other than missing,
23+ Let `expr` be `:( x[1]) `. It is an assumption in the following cases:
24+ 1. `x` is not among the input data to the ` model` ,
25+ 2. `x` is among the input data to the ` model` but with a value `missing`, or
26+ 3. `x` is among the input data to the ` model` with a value other than missing,
3627 but `x[1] === missing`.
28+
3729When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
3830"""
39- macro isassumption (model, expr:: Union{Symbol, Expr} )
40- # Note: never put a return in this... don't forget it's a macro!
31+ function isassumption (model, expr:: Union{Symbol, Expr} )
4132 vn = gensym (:vn )
42-
33+
4334 return quote
44- $ vn = @varname ($ expr)
45-
46- # This branch should compile nicely in all cases except for partial missing data
47- # For example, when `expr` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
48- if ! DynamicPPL. inargnames ($ vn, $ model) || DynamicPPL. inmissings ($ vn, $ model)
49- true
50- else
51- if DynamicPPL. inargnames ($ vn, $ model)
52- # Evaluate the lhs
53- $ expr === missing
35+ let $ vn = $ (varname (expr))
36+ # This branch should compile nicely in all cases except for partial missing data
37+ # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
38+ if ! $ (DynamicPPL. inargnames)($ vn, $ model) || $ (DynamicPPL. inmissings)($ vn, $ model)
39+ true
5440 else
55- throw (" This point should not be reached. Please report this error." )
41+ # Evaluate the LHS
42+ $ expr === missing
5643 end
5744 end
58- end |> esc
59- end
60-
61- macro isassumption (model, expr)
62- # failsafe: a literal is never an assumption
63- false
45+ end
6446end
6547
66-
48+ # failsafe: a literal is never an assumption
49+ isassumption (model, expr) = :(false )
6750
6851# ################
6952# Main Compiler #
@@ -128,7 +111,7 @@ function build_model_info(input_expr)
128111 Expr (:tuple , QuoteNode .(arg_syms)... ),
129112 Expr (:curly , :Tuple , [:(Core. Typeof ($ x)) for x in arg_syms]. .. )
130113 )
131- args_nt = Expr (:call , :(DynamicPPL . namedtuple), nt_type, Expr (:tuple , arg_syms... ))
114+ args_nt = Expr (:call , :($ namedtuple), nt_type, Expr (:tuple , arg_syms... ))
132115 end
133116 args = map (modeldef[:args ]) do arg
134117 if (arg isa Symbol)
@@ -217,7 +200,7 @@ function replace_logpdf!(model_info)
217200 vi = model_info[:main_body_names ][:vi ]
218201 ex = MacroTools. postwalk (ex) do x
219202 if @capture (x, @logpdf ())
220- :($ vi . logp[])
203+ :($ (vi) . logp[])
221204 else
222205 x
223206 end
@@ -261,14 +244,14 @@ function replace_tilde!(model_info)
261244 dotargs = getargs_dottilde (x)
262245 if dotargs != = nothing
263246 L, R = dotargs
264- return generate_dot_tilde (L, R, model_info)
247+ return Base . remove_linenums! ( generate_dot_tilde (L, R, model_info) )
265248 end
266249
267250 # Check tilde.
268251 args = getargs_tilde (x)
269252 if args != = nothing
270253 L, R = args
271- return generate_tilde (L, R, model_info)
254+ return Base . remove_linenums! ( generate_tilde (L, R, model_info) )
272255 end
273256
274257 return x
@@ -294,45 +277,55 @@ function generate_tilde(left, right, model_info)
294277 vi = model_info[:main_body_names ][:vi ]
295278 ctx = model_info[:main_body_names ][:ctx ]
296279 sampler = model_info[:main_body_names ][:sampler ]
297- temp_right = gensym (:temp_right )
298- out = gensym (:out )
299- lp = gensym (:lp )
300- vn = gensym (:vn )
301- inds = gensym (:inds )
302- isassumption = gensym (:isassumption )
303- assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
304-
280+
281+ @gensym tmpright
282+ top = [:($ tmpright = $ right),
283+ :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
284+ || throw (ArgumentError ($ DISTMSG)))]
285+
305286 if left isa Symbol || left isa Expr
306- ex = quote
307- $ temp_right = $ right
308- $ assert_ex
309-
310- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
311- $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
312- if $ isassumption
313- $ out = DynamicPPL. tilde_assume ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
314- $ left = $ out[1 ]
315- DynamicPPL. acclogp! ($ vi, $ out[2 ])
316- else
317- DynamicPPL. acclogp! (
318- $ vi,
319- DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
320- )
287+ @gensym out vn inds
288+ push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
289+
290+ assumption = [
291+ :($ out = $ (DynamicPPL. tilde_assume)($ ctx, $ sampler, $ tmpright, $ vn, $ inds,
292+ $ vi)),
293+ :($ left = $ out[1 ]),
294+ :($ (DynamicPPL. acclogp!)($ vi, $ out[2 ]))
295+ ]
296+
297+ # It can only be an observation if the LHS is an argument of the model
298+ if vsym (left) in model_info[:args ]
299+ @gensym isassumption
300+ return quote
301+ $ (top... )
302+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
303+ if $ isassumption
304+ $ (assumption... )
305+ else
306+ $ (DynamicPPL. acclogp!)(
307+ $ vi,
308+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vn,
309+ $ inds, $ vi)
310+ )
311+ end
321312 end
322313 end
323- else
324- # we have a literal, which is automatically an observation
325- ex = quote
326- $ temp_right = $ right
327- $ assert_ex
328-
329- DynamicPPL. acclogp! (
330- $ vi,
331- DynamicPPL. tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
332- )
314+
315+ return quote
316+ $ (top... )
317+ $ (assumption... )
333318 end
334319 end
335- return ex
320+
321+ # If the LHS is a literal, it is always an observation
322+ return quote
323+ $ (top... )
324+ $ (DynamicPPL. acclogp!)(
325+ $ vi,
326+ $ (DynamicPPL. tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
327+ )
328+ end
336329end
337330
338331"""
@@ -347,46 +340,55 @@ function generate_dot_tilde(left, right, model_info)
347340 vi = model_info[:main_body_names ][:vi ]
348341 ctx = model_info[:main_body_names ][:ctx ]
349342 sampler = model_info[:main_body_names ][:sampler ]
350- out = gensym (:out )
351- temp_right = gensym (:temp_right )
352- isassumption = gensym (:isassumption )
353- lp = gensym (:lp )
354- vn = gensym (:vn )
355- inds = gensym (:inds )
356- assert_ex = :(DynamicPPL. assert_dist ($ temp_right, msg = $ (wrong_dist_errormsg (@__LINE__ ))))
357-
343+
344+ @gensym tmpright
345+ top = [:($ tmpright = $ right),
346+ :($ tmpright isa Union{$ Distribution,AbstractVector{<: $Distribution }}
347+ || throw (ArgumentError ($ DISTMSG)))]
348+
358349 if left isa Symbol || left isa Expr
359- ex = quote
360- $ temp_right = $ right
361- $ assert_ex
362-
363- $ vn, $ inds = $ (varname (left)), $ (vinds (left))
364- $ isassumption = DynamicPPL. @isassumption ($ model, $ left)
365-
366- if $ isassumption
367- $ out = DynamicPPL. dot_tilde_assume ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi)
368- $ left .= $ out[1 ]
369- DynamicPPL. acclogp! ($ vi, $ out[2 ])
370- else
371- DynamicPPL. acclogp! (
372- $ vi,
373- DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vn, $ inds, $ vi),
374- )
350+ @gensym out vn inds
351+ push! (top, :($ vn = $ (varname (left))), :($ inds = $ (vinds (left))))
352+
353+ assumption = [
354+ :($ out = $ (DynamicPPL. dot_tilde_assume)($ ctx, $ sampler, $ tmpright, $ left,
355+ $ vn, $ inds, $ vi)),
356+ :($ left .= $ out[1 ]),
357+ :($ (DynamicPPL. acclogp!)($ vi, $ out[2 ]))
358+ ]
359+
360+ # It can only be an observation if the LHS is an argument of the model
361+ if vsym (left) in model_info[:args ]
362+ @gensym isassumption
363+ return quote
364+ $ (top... )
365+ $ isassumption = $ (DynamicPPL. isassumption (model, left))
366+ if $ isassumption
367+ $ (assumption... )
368+ else
369+ $ (DynamicPPL. acclogp!)(
370+ $ vi,
371+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left,
372+ $ vn, $ inds, $ vi)
373+ )
374+ end
375375 end
376376 end
377- else
378- # we have a literal, which is automatically an observation
379- ex = quote
380- $ temp_right = $ right
381- $ assert_ex
382-
383- DynamicPPL. acclogp! (
384- $ vi,
385- DynamicPPL. dot_tilde_observe ($ ctx, $ sampler, $ temp_right, $ left, $ vi),
386- )
377+
378+ return quote
379+ $ (top... )
380+ $ (assumption... )
387381 end
388382 end
389- return ex
383+
384+ # If the LHS is a literal, it is always an observation
385+ return quote
386+ $ (top... )
387+ $ (DynamicPPL. acclogp!)(
388+ $ vi,
389+ $ (DynamicPPL. dot_tilde_observe)($ ctx, $ sampler, $ tmpright, $ left, $ vi)
390+ )
391+ end
390392end
391393
392394const FloatOrArrayType = Type{<: Union{AbstractFloat, AbstractArray} }
@@ -425,42 +427,29 @@ function build_output(model_info)
425427
426428 unwrap_data_expr = Expr (:block )
427429 for var in arg_syms
428- temp_var = gensym (:temp_var )
429- varT = gensym (:varT )
430- push! (unwrap_data_expr. args, quote
431- local $ var
432- $ temp_var = $ model. args.$ var
433- $ varT = typeof ($ temp_var)
434- if $ temp_var isa DynamicPPL. FloatOrArrayType
435- $ var = DynamicPPL. get_matching_type ($ sampler, $ vi, $ temp_var)
436- elseif DynamicPPL. hasmissing ($ varT)
437- $ var = DynamicPPL. get_matching_type ($ sampler, $ vi, $ varT)($ temp_var)
438- else
439- $ var = $ temp_var
440- end
441- end )
430+ push! (unwrap_data_expr. args,
431+ :($ var = $ (DynamicPPL. matchingvalue)($ sampler, $ vi, $ (model). args.$ var)))
442432 end
443433
444434 @gensym (evaluator, generator)
445435 generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
446- model_gen_constructor = :(DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
447-
436+ model_gen_constructor = :($ ( DynamicPPL. ModelGen) {$ (Tuple (arg_syms))}($ generator, $ defaults_nt))
437+
448438 ex = quote
449439 function $evaluator (
450- $ model:: Model ,
451- $ vi:: DynamicPPL.VarInfo ,
452- $ sampler:: DynamicPPL.AbstractSampler ,
453- $ ctx:: DynamicPPL.AbstractContext ,
440+ $ model:: $ (DynamicPPL . Model) ,
441+ $ vi:: $ ( DynamicPPL. VarInfo) ,
442+ $ sampler:: $ ( DynamicPPL. AbstractSampler) ,
443+ $ ctx:: $ ( DynamicPPL. AbstractContext) ,
454444 )
455445 $ unwrap_data_expr
456- DynamicPPL. resetlogp! ($ vi)
446+ $ ( DynamicPPL. resetlogp!) ($ vi)
457447 $ main_body
458448 end
459-
460449
461- $ generator ($ (args... )) = DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
450+ $ generator ($ (args... )) = $ ( DynamicPPL. Model) ($ evaluator, $ args_nt, $ model_gen_constructor)
462451 $ (generator_kw_form... )
463-
452+
464453 $ model_gen = $ model_gen_constructor
465454 end
466455
@@ -475,6 +464,21 @@ function warn_empty(body)
475464 return
476465end
477466
467+ """
468+ matchingvalue(sampler, vi, value)
469+
470+ Convert the `value` to the correct type for the `sampler` and the `vi` object.
471+ """
472+ function matchingvalue (sampler, vi, value)
473+ T = typeof (value)
474+ if hasmissing (T)
475+ return get_matching_type (sampler, vi, T)(value)
476+ else
477+ return value
478+ end
479+ end
480+ matchingvalue (sampler, vi, value:: FloatOrArrayType ) = get_matching_type (sampler, vi, value)
481+
478482"""
479483 get_matching_type(spl, vi, ::Type{T}) where {T}
480484Get the specialized version of type `T` for sampler `spl`. For example,
0 commit comments