1- """
2- struct ModelGen{Targs, F, Tdefaults} <: Function
3- f::F
4- defaults::Tdefaults
5- end
6-
7- A `Model` generator. This is the output of the `@model` macro. `Targs` is the tuple
8- of the symbols of the model's arguments. `defaults` is the `NamedTuple` of default values
9- of the arguments, if any. Every `ModelGen` is callable with the arguments `Targs`,
10- returning an instance of `Model`.
11- """
12- struct ModelGen{Targs, F, Tdefaults} <: Function
13- f:: F
14- defaults:: Tdefaults
15- end
16- ModelGen {Targs} (args... ) where {Targs} = ModelGen {Targs, typeof.(args)...} (args... )
17- (m:: ModelGen )(args... ; kwargs... ) = m. f (args... ; kwargs... )
18- function Base. getproperty (m:: ModelGen{Targs} , f:: Symbol ) where {Targs}
19- f === :args && return Targs
20- return Base. getfield (m, f)
21- end
22-
231macro varinfo ()
242 :(throw (_error_msg ()))
253end
@@ -61,18 +39,18 @@ Otherwise, the value of `x[1]` is returned.
6139macro preprocess (data_vars, missing_vars, ex)
6240 ex
6341end
64- macro preprocess (data_vars, missing_vars , ex:: Union{Symbol, Expr} )
42+ macro preprocess (model , ex:: Union{Symbol, Expr} )
6543 sym = gensym (:sym )
6644 lhs = gensym (:lhs )
6745 return esc (quote
6846 # Extract symbol
6947 $ sym = Val ($ (vsym (ex)))
7048 # This branch should compile nicely in all cases except for partial missing data
7149 # For example, when `ex` is `x[i]` and `x isa Vector{Union{Missing, Float64}}`
72- if ! DynamicPPL. inparams ($ sym, $ data_vars ) || DynamicPPL. inparams ($ sym, $ missing_vars )
50+ if ! DynamicPPL. inargnames ($ sym, $ model ) || DynamicPPL. inmissings ($ sym, $ model )
7351 $ (varname (ex)), $ (vinds (ex))
7452 else
75- if DynamicPPL. inparams ($ sym, $ data_vars )
53+ if DynamicPPL. inargnames ($ sym, $ model )
7654 # Evaluate the lhs
7755 $ lhs = $ ex
7856 if $ lhs === missing
@@ -86,9 +64,7 @@ macro preprocess(data_vars, missing_vars, ex::Union{Symbol, Expr})
8664 end
8765 end )
8866end
89- @generated function inparams (:: Val{s} , :: Val{t} ) where {s, t}
90- return (s in t) ? :(true ) : :(false )
91- end
67+
9268
9369# ################
9470# Main Compiler #
@@ -151,7 +127,7 @@ function build_model_info(input_expr)
151127 else
152128 nt_type = Expr (:curly , :NamedTuple ,
153129 Expr (:tuple , QuoteNode .(arg_syms)... ),
154- Expr (:curly , :Tuple , [:(DynamicPPL . get_type ($ x)) for x in arg_syms]. .. )
130+ Expr (:curly , :Tuple , [:(Core . Typeof ($ x)) for x in arg_syms]. .. )
155131 )
156132 args_nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , arg_syms... ))
157133 end
@@ -205,27 +181,13 @@ function build_model_info(input_expr)
205181 :ctx => gensym (:ctx ),
206182 :vi => gensym (:vi ),
207183 :sampler => gensym (:sampler ),
208- :model => gensym (:model ),
209- :inner_function => gensym (:inner_function ),
210- :defaults => gensym (:defaults )
184+ :model => gensym (:model )
211185 )
212186 )
213187
214188 return model_info
215189end
216190
217- function to_namedtuple_expr (syms:: Vector , vals = syms)
218- if length (syms) == 0
219- nt = :(NamedTuple ())
220- else
221- nt_type = Expr (:curly , :NamedTuple ,
222- Expr (:tuple , QuoteNode .(syms)... ),
223- Expr (:curly , :Tuple , [:(DynamicPPL. get_type ($ x)) for x in vals]. .. )
224- )
225- nt = Expr (:call , :(DynamicPPL. namedtuple), nt_type, Expr (:tuple , vals... ))
226- end
227- return nt
228- end
229191
230192"""
231193 replace_vi!(model_info)
@@ -319,14 +281,16 @@ function replace_tilde!(model_info)
319281end
320282""" |> Meta. parse |> eval
321283
284+ # """ Unbreak code highlighting in Emacs julia-mode
285+
286+
322287"""
323288 generate_tilde(left, right, model_info)
324289
325290The `tilde` function generates `observe` expression for data variables and `assume`
326291expressions for parameter variables, updating `model_info` in the process.
327292"""
328293function generate_tilde (left, right, model_info)
329- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
330294 model = model_info[:main_body_names ][:model ]
331295 vi = model_info[:main_body_names ][:vi ]
332296 ctx = model_info[:main_body_names ][:ctx ]
@@ -342,7 +306,7 @@ function generate_tilde(left, right, model_info)
342306 ex = quote
343307 $ temp_right = $ right
344308 $ assert_ex
345- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
309+ $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
346310 if $ preprocessed isa Tuple
347311 $ vn, $ inds = $ preprocessed
348312 $ out = DynamicPPL. tilde ($ ctx, $ sampler, $ temp_right, $ vn, $ inds, $ vi)
374338This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
375339"""
376340function generate_dot_tilde (left, right, model_info)
377- arg_syms = Val ((model_info[:arg_syms ]. .. ,))
378341 model = model_info[:main_body_names ][:model ]
379342 vi = model_info[:main_body_names ][:vi ]
380343 ctx = model_info[:main_body_names ][:ctx ]
@@ -391,7 +354,7 @@ function generate_dot_tilde(left, right, model_info)
391354 ex = quote
392355 $ temp_right = $ right
393356 $ assert_ex
394- $ preprocessed = DynamicPPL. @preprocess ($ arg_syms, DynamicPPL . getmissing ( $ model) , $ left)
357+ $ preprocessed = DynamicPPL. @preprocess ($ model, $ left)
395358 if $ preprocessed isa Tuple
396359 $ vn, $ inds = $ preprocessed
397360 $ temp_left = $ left
@@ -437,7 +400,6 @@ function build_output(model_info)
437400 vi = main_body_names[:vi ]
438401 model = main_body_names[:model ]
439402 sampler = main_body_names[:sampler ]
440- inner_function = main_body_names[:inner_function ]
441403
442404 # Arguments with default values
443405 args = model_info[:args ]
@@ -452,16 +414,9 @@ function build_output(model_info)
452414 whereparams = model_info[:whereparams ]
453415 # Model generator name
454416 model_gen = model_info[:name ]
455- # Outer function name
456- outer_function = gensym (model_info[:name ])
457417 # Main body of the model
458418 main_body = model_info[:main_body ]
459- model_gen_constructor = quote
460- DynamicPPL. ModelGen {$(Tuple(arg_syms))} (
461- $ outer_function,
462- $ defaults_nt,
463- )
464- end
419+
465420 unwrap_data_expr = Expr (:block )
466421 for var in arg_syms
467422 temp_var = gensym (:temp_var )
@@ -480,40 +435,32 @@ function build_output(model_info)
480435 end )
481436 end
482437
438+ @gensym (evaluator, generator)
439+ generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
440+ model_gen_constructor = :(DynamicPPL. ModelGen {$(Tuple(arg_syms))} ($ generator, $ defaults_nt))
441+
483442 ex = quote
484- function $outer_function ($ (args... ))
485- function $inner_function (
486- $ vi:: DynamicPPL.VarInfo ,
487- $ sampler:: DynamicPPL.AbstractSampler ,
488- $ ctx:: DynamicPPL.AbstractContext ,
489- $ model
490- )
491- $ unwrap_data_expr
492- DynamicPPL. resetlogp! ($ vi)
493- $ main_body
494- end
495- return DynamicPPL. Model ($ inner_function, $ args_nt, $ model_gen_constructor)
443+ function $evaluator (
444+ $ model:: Model ,
445+ $ vi:: DynamicPPL.VarInfo ,
446+ $ sampler:: DynamicPPL.AbstractSampler ,
447+ $ ctx:: DynamicPPL.AbstractContext ,
448+ )
449+ $ unwrap_data_expr
450+ DynamicPPL. resetlogp! ($ vi)
451+ $ main_body
496452 end
497- $ model_gen = $ model_gen_constructor
498- end
453+
499454
500- if ! isempty (args)
501- ex = quote
502- $ ex
503- # Allows passing arguments as kwargs
504- $ outer_function (;$ (args... )) = $ outer_function ($ (arg_syms... ))
505- end
455+ $ generator ($ (args... )) = DynamicPPL. Model ($ evaluator, $ args_nt, $ model_gen_constructor)
456+ $ (generator_kw_form... )
457+
458+ $ model_gen = $ model_gen_constructor
506459 end
507460
508461 return esc (ex)
509462end
510463
511- # A hack for NamedTuple type specialization
512- # (T = Int,) has type NamedTuple{(:T,), Tuple{DataType}} by default
513- # With this function, we can make it NamedTuple{(:T,), Tuple{Type{Int}}}
514- # Both are correct, but the latter is what we want for type stability
515- get_type (:: Type{T} ) where {T} = Type{T}
516- get_type (t) = typeof (t)
517464
518465function warn_empty (body)
519466 if all (l -> isa (l, LineNumberNode), body. args)
0 commit comments