6969
7070To generate a `Model`, call `model_generator(x_value)`.
7171"""
72- macro model (input_expr)
73- Base. replace_ref_end! (input_expr) |> build_model_info |> replace_tilde! |> replace_vi! |>
74- replace_logpdf! |> replace_sampler! |> build_output
72+ macro model (expr)
73+ esc (model (expr))
74+ end
75+
76+ function model (expr)
77+ modelinfo = build_model_info (expr)
78+
79+ # Generate main body
80+ modelinfo[:main_body ] = generate_mainbody (modelinfo)
81+
82+ return build_output (modelinfo)
7583end
7684
7785"""
@@ -170,62 +178,55 @@ function build_model_info(input_expr)
170178 return model_info
171179end
172180
173-
174181"""
175- replace_vi!(model_info )
182+ generate_mainbody([expr, ]modelinfo )
176183
177- Replaces `@varinfo()` expressions with a handle to the `VarInfo` struct .
184+ Generate the body of the main evaluation function .
178185"""
179- function replace_vi! (model_info)
180- ex = model_info[:main_body ]
181- vi = model_info[:main_body_names ][:vi ]
182- ex = MacroTools. postwalk (ex) do x
183- if @capture (x, @varinfo ())
184- vi
185- else
186- x
187- end
188- end
189- model_info[:main_body ] = ex
190- return model_info
191- end
186+ generate_mainbody (modelinfo) = generate_mainbody (modelinfo[:main_body ], modelinfo)
192187
193- """
194- replace_logpdf!(model_info)
188+ generate_mainbody (x, modelinfo) = x
189+ function generate_mainbody (expr:: Expr , modelinfo)
190+ # Do not touch interpolated expressions
191+ expr. head === :$ && return expr. args[1 ]
195192
196- Replaces `@logpdf()` expressions with the value of the accumulated `logpdf` in the `VarInfo` struct.
197- """
198- function replace_logpdf! (model_info)
199- ex = model_info[:main_body ]
200- vi = model_info[:main_body_names ][:vi ]
201- ex = MacroTools. postwalk (ex) do x
202- if @capture (x, @logpdf ())
203- :($ (vi). logp[])
204- else
205- x
193+ # Apply the `@.` macro first.
194+ if Meta. isexpr (expr, :macrocall ) && length (expr. args) > 1 &&
195+ expr. args[1 ] === Symbol (" @__dot__" )
196+ return generate_mainbody (Base. Broadcast. __dot__ (expr. args[end ]), modelinfo)
197+ end
198+
199+ # Modify macro calls.
200+ if Meta. isexpr (expr, :macrocall ) && ! isempty (expr. args)
201+ name = expr. args[1 ]
202+ if name === Symbol (" @varinfo" )
203+ return modelinfo[:main_body_names ][:vi ]
204+ elseif name === Symbol (" @logpdf" )
205+ return :($ (modelinfo[:main_body_names ][:vi ]). logp[])
206+ elseif name === Symbol (" @sampler" )
207+ return :($ (modelinfo[:main_body_names ][:sampler ]))
206208 end
207209 end
208- model_info[:main_body ] = ex
209- return model_info
210- end
211210
212- """
213- replace_sampler!(model_info)
211+ # Modify dotted tilde operators.
212+ args_dottilde = getargs_dottilde (expr)
213+ if args_dottilde != = nothing
214+ L, R = args_dottilde
215+ return Base. remove_linenums! (generate_dot_tilde (generate_mainbody (L, modelinfo),
216+ generate_mainbody (R, modelinfo),
217+ modelinfo))
218+ end
214219
215- Replaces `@sampler()` expressions with a handle to the sampler struct.
216- """
217- function replace_sampler! (model_info)
218- ex = model_info[:main_body ]
219- spl = model_info[:main_body_names ][:sampler ]
220- ex = MacroTools. postwalk (ex) do x
221- if @capture (x, @sampler ())
222- spl
223- else
224- x
225- end
220+ # Modify tilde operators.
221+ args_tilde = getargs_tilde (expr)
222+ if args_tilde != = nothing
223+ L, R = args_tilde
224+ return Base. remove_linenums! (generate_tilde (generate_mainbody (L, modelinfo),
225+ generate_mainbody (R, modelinfo),
226+ modelinfo))
226227 end
227- model_info[ :main_body ] = ex
228- return model_info
228+
229+ return Expr (expr . head, map (x -> generate_mainbody (x, modelinfo), expr . args) ... )
229230end
230231
231232"""
@@ -443,7 +444,7 @@ function build_output(model_info)
443444 generator_kw_form = isempty (args) ? () : (:($ generator (;$ (args... )) = $ generator ($ (arg_syms... ))),)
444445 model_gen_constructor = :($ (DynamicPPL. ModelGen){$ (Tuple (arg_syms))}($ generator, $ defaults_nt))
445446
446- ex = quote
447+ return quote
447448 function $evaluator (
448449 $ model:: $ (DynamicPPL. Model),
449450 $ vi:: $ (DynamicPPL. VarInfo),
@@ -459,8 +460,6 @@ function build_output(model_info)
459460
460461 $ model_gen = $ model_gen_constructor
461462 end
462-
463- return esc (ex)
464463end
465464
466465
0 commit comments