11const INTERNALNAMES = (:__model__ , :__sampler__ , :__context__ , :__varinfo__ , :__rng__ )
22const DEPRECATED_INTERNALNAMES = (:_model , :_sampler , :_context , :_varinfo , :_rng )
33
4- """
5- isassumption(expr)
6-
7- Return an expression that can be evaluated to check if `expr` is an assumption in the
8- model.
9-
10- Let `expr` be `:(x[1])`. It is an assumption in the following cases:
11- 1. `x` is not among the input data to the model,
12- 2. `x` is among the input data to the model but with a value `missing`, or
13- 3. `x` is among the input data to the model with a value other than missing,
14- but `x[1] === missing`.
15-
16- When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
17- """
18- function isassumption (expr:: Union{Symbol,Expr} )
19- vn = gensym (:vn )
20-
21- return quote
22- let $ vn = $ (varname (expr))
23- # This branch should compile nicely in all cases except for partial missing data
24- # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
25- if ! $ (DynamicPPL. inargnames)($ vn, __model__) ||
26- $ (DynamicPPL. inmissings)($ vn, __model__)
27- true
28- else
29- # Evaluate the LHS
30- $ expr === missing
31- end
32- end
33- end
34- end
35-
36- # failsafe: a literal is never an assumption
37- isassumption (expr) = :(false )
384
395"""
406 isliteral(expr)
137103function model (mod, linenumbernode, expr, warn)
138104 modelinfo = build_model_info (expr)
139105
140- # Generate main body
141- modelinfo[:body ] = generate_mainbody (mod, modelinfo[:modeldef ][:body ], warn)
106+ # Generate main body and find all variable symbols
107+ modelinfo[:body ], modelinfo[:varnames ] = generate_mainbody (
108+ mod, modelinfo[:modeldef ][:body ], warn
109+ )
110+
111+ # extract observations from that
112+ modelinfo[:obsnames ] = modelinfo[:allargs_syms ] ∩ modelinfo[:varnames ]
113+ modelinfo[:latentnames ] = setdiff (modelinfo[:varnames ], modelinfo[:allargs_syms ])
142114
143115 return build_output (modelinfo, linenumbernode)
144116end
@@ -167,8 +139,7 @@ function build_model_info(input_expr)
167139 modelinfo = Dict (
168140 :allargs_exprs => [],
169141 :allargs_syms => [],
170- :allargs_namedtuple => NamedTuple (),
171- :defaults_namedtuple => NamedTuple (),
142+ :allargs_defaults => [],
172143 :modeldef => modeldef,
173144 )
174145 return modelinfo
@@ -177,17 +148,18 @@ function build_model_info(input_expr)
177148 # Extract the positional and keyword arguments from the model definition.
178149 allargs = vcat (modeldef[:args ], modeldef[:kwargs ])
179150
180- # Split the argument expressions and the default values.
181- allargs_exprs_defaults = map (allargs) do arg
182- MacroTools. @match arg begin
151+ # Split the argument expressions and the default values, by unzipping allargs, taking care of
152+ # the empty case
153+ allargs_exprs, allargs_defaults = foldl (allargs; init= ([], [])) do (ae, ad), arg
154+ (expr, default) = MacroTools. @match arg begin
183155 (x_ = val_) => (x, val)
184156 x_ => (x, NO_DEFAULT)
185157 end
158+ push! (ae, expr)
159+ push! (ad, default)
160+ ae, ad
186161 end
187-
188- # Extract the expressions of the arguments, without default values.
189- allargs_exprs = first .(allargs_exprs_defaults)
190-
162+
191163 # Extract the names of the arguments.
192164 allargs_syms = map (allargs_exprs) do arg
193165 MacroTools. @match arg begin
@@ -196,28 +168,11 @@ function build_model_info(input_expr)
196168 x_ => x
197169 end
198170 end
199-
200- # Build named tuple expression of the argument symbols and variables of the same name.
201- allargs_namedtuple = to_namedtuple_expr (allargs_syms)
202-
203- # Extract default values of the positional and keyword arguments.
204- default_syms = []
205- default_vals = []
206- for (sym, (expr, val)) in zip (allargs_syms, allargs_exprs_defaults)
207- if val != = NO_DEFAULT
208- push! (default_syms, sym)
209- push! (default_vals, val)
210- end
211- end
212-
213- # Build named tuple expression of the argument symbols with default values.
214- defaults_namedtuple = to_namedtuple_expr (default_syms, default_vals)
215-
171+
216172 modelinfo = Dict (
217173 :allargs_exprs => allargs_exprs,
218174 :allargs_syms => allargs_syms,
219- :allargs_namedtuple => allargs_namedtuple,
220- :defaults_namedtuple => defaults_namedtuple,
175+ :allargs_defaults => allargs_defaults,
221176 :modeldef => modeldef,
222177 )
223178
@@ -233,43 +188,50 @@ Generate the body of the main evaluation function from expression `expr` and arg
233188If `warn` is true, a warning is displayed if internal variables are used in the model
234189definition.
235190"""
236- generate_mainbody (mod, expr, warn) = generate_mainbody! (mod, Symbol[], expr, warn)
191+ function generate_mainbody (mod, expr, warn)
192+ varnames = Symbol[]
193+ body = generate_mainbody! (mod, Symbol[], varnames, expr, warn)
194+ return body, varnames
195+ end
237196
238- generate_mainbody! (mod, found , x, warn) = x
239- function generate_mainbody! (mod, found , sym:: Symbol , warn)
197+ generate_mainbody! (mod, found_internals, varnames , x, warn) = x
198+ function generate_mainbody! (mod, found_internals , sym:: Symbol , warn)
240199 if sym in DEPRECATED_INTERNALNAMES
241200 newsym = Symbol (:_ , sym, :__ )
242201 Base. depwarn (
243202 " internal variable `$sym ` is deprecated, use `$newsym ` instead." ,
244203 :generate_mainbody! ,
245204 )
246- return generate_mainbody! (mod, found , newsym, warn)
205+ return generate_mainbody! (mod, found_internals , newsym, warn)
247206 end
248207
249- if warn && sym in INTERNALNAMES && sym ∉ found
208+ if warn && sym in INTERNALNAMES && sym ∉ found_internals
250209 @warn " you are using the internal variable `$sym `"
251- push! (found , sym)
210+ push! (found_internals , sym)
252211 end
253212
254213 return sym
255214end
256- function generate_mainbody! (mod, found , expr:: Expr , warn)
215+ function generate_mainbody! (mod, found_internals, varnames , expr:: Expr , warn)
257216 # Do not touch interpolated expressions
258217 expr. head === :$ && return expr. args[1 ]
259218
260219 # If it's a macro, we expand it
261220 if Meta. isexpr (expr, :macrocall )
262- return generate_mainbody! (mod, found, macroexpand (mod, expr; recursive= true ), warn)
221+ return generate_mainbody! (
222+ mod, found_internals, varnames, macroexpand (mod, expr; recursive= true ), warn
223+ )
263224 end
264225
265226 # Modify dotted tilde operators.
266227 args_dottilde = getargs_dottilde (expr)
267228 if args_dottilde != = nothing
268229 L, R = args_dottilde
230+ ! isliteral (L) && push! (varnames, vsym (L))
269231 return Base. remove_linenums! (
270232 generate_dot_tilde (
271- generate_mainbody! (mod, found , L, warn),
272- generate_mainbody! (mod, found , R, warn),
233+ generate_mainbody! (mod, found_internals, varnames , L, warn),
234+ generate_mainbody! (mod, found_internals, varnames , R, warn),
273235 ),
274236 )
275237 end
@@ -278,15 +240,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
278240 args_tilde = getargs_tilde (expr)
279241 if args_tilde != = nothing
280242 L, R = args_tilde
243+ ! isliteral (L) && push! (varnames, vsym (L))
281244 return Base. remove_linenums! (
282245 generate_tilde (
283- generate_mainbody! (mod, found , L, warn),
284- generate_mainbody! (mod, found , R, warn),
246+ generate_mainbody! (mod, found_internals, varnames , L, warn),
247+ generate_mainbody! (mod, found_internals, varnames , R, warn),
285248 ),
286249 )
287250 end
288251
289- return Expr (expr. head, map (x -> generate_mainbody! (mod, found, x, warn), expr. args)... )
252+ return Expr (
253+ expr. head,
254+ map (x -> generate_mainbody! (mod, found_internals, varnames, x, warn), expr. args)... ,
255+ )
290256end
291257
292258"""
@@ -307,26 +273,26 @@ function generate_tilde(left, right)
307273
308274 # Otherwise it is determined by the model or its value,
309275 # if the LHS represents an observation
310- @gensym vn inds isassumption
276+ @gensym vn inds isobservation
311277 return quote
312278 $ vn = $ (varname (left))
313279 $ inds = $ (vinds (left))
314- $ isassumption = $ (DynamicPPL. isassumption (left) )
315- if $ isassumption
316- $ left = $ (DynamicPPL. tilde_assume !)(
280+ $ isobservation = $ (DynamicPPL. isobservation)( $ vn, __model__ )
281+ if $ isobservation
282+ $ (DynamicPPL. tilde_observe !)(
317283 __context__,
318- $ (DynamicPPL. unwrap_right_vn)(
319- $ (DynamicPPL . check_tilde_rhs)( $ right), $ vn
320- ) . .. ,
284+ $ (DynamicPPL. check_tilde_rhs)( $ right),
285+ $ left,
286+ $ vn ,
321287 $ inds,
322288 __varinfo__,
323289 )
324290 else
325- $ (DynamicPPL. tilde_observe !)(
291+ $ left = $ (DynamicPPL. tilde_assume !)(
326292 __context__,
327- $ (DynamicPPL. check_tilde_rhs)( $ right),
328- $ left,
329- $ vn ,
293+ $ (DynamicPPL. unwrap_right_vn)(
294+ $ (DynamicPPL . check_tilde_rhs)( $ right), $ vn
295+ ) . .. ,
330296 $ inds,
331297 __varinfo__,
332298 )
@@ -351,26 +317,26 @@ function generate_dot_tilde(left, right)
351317
352318 # Otherwise it is determined by the model or its value,
353319 # if the LHS represents an observation
354- @gensym vn inds isassumption
320+ @gensym vn inds isobservation
355321 return quote
356322 $ vn = $ (varname (left))
357323 $ inds = $ (vinds (left))
358- $ isassumption = $ (DynamicPPL. isassumption (left) )
359- if $ isassumption
360- $ left . = $ (DynamicPPL. dot_tilde_assume !)(
324+ $ isobservation = $ (DynamicPPL. isobservation)( $ vn, __model__ )
325+ if $ isobservation
326+ $ (DynamicPPL. dot_tilde_observe !)(
361327 __context__,
362- $ (DynamicPPL. unwrap_right_left_vns)(
363- $ (DynamicPPL . check_tilde_rhs)( $ right), $ left, $ vn
364- ) . .. ,
328+ $ (DynamicPPL. check_tilde_rhs)( $ right),
329+ $ left,
330+ $ vn ,
365331 $ inds,
366332 __varinfo__,
367333 )
368334 else
369- $ (DynamicPPL. dot_tilde_observe !)(
335+ $ left . = $ (DynamicPPL. dot_tilde_assume !)(
370336 __context__,
371- $ (DynamicPPL. check_tilde_rhs)( $ right),
372- $ left,
373- $ vn ,
337+ $ (DynamicPPL. unwrap_right_left_vns)(
338+ $ (DynamicPPL . check_tilde_rhs)( $ right), $ left, $ vn
339+ ) . .. ,
374340 $ inds,
375341 __varinfo__,
376342 )
@@ -413,10 +379,24 @@ function build_output(modelinfo, linenumbernode)
413379
414380 # # Build the model function.
415381
416- # Extract the named tuple expression of all arguments and the default values.
417- allargs_namedtuple = modelinfo[:allargs_namedtuple ]
418- defaults_namedtuple = modelinfo[:defaults_namedtuple ]
382+ # Extract the named tuple expression of all arguments
383+ allargs_newnames = [gensym (x) for x in modelinfo[:allargs_syms ]]
384+ allargs_wrapped = map (modelinfo[:allargs_syms ], modelinfo[:allargs_defaults ]) do x, d
385+ if x ∈ modelinfo[:obsnames ]
386+ :($ (DynamicPPL. Variable)($ x, $ d))
387+ else
388+ :($ (DynamicPPL. Constant)($ x, $ d))
389+ end
390+ end
391+ allargs_decls = [:($ name = $ val) for (name, val) in zip (allargs_newnames, allargs_wrapped)]
392+ allargs_namedtuple = to_namedtuple_expr (modelinfo[:allargs_syms ], allargs_newnames)
419393
394+ internals_newnames = [gensym (x) for x in modelinfo[:latentnames ]]
395+ internals_decls = map (internals_newnames) do name
396+ :($ name = $ (DynamicPPL. Variable)(missing ))
397+ end
398+ internals_namedtuple = to_namedtuple_expr (modelinfo[:latentnames ], internals_newnames)
399+
420400 # Update the function body of the user-specified model.
421401 # We use a name for the anonymous evaluator that does not conflict with other variables.
422402 modeldef = modelinfo[:modeldef ]
@@ -427,11 +407,13 @@ function build_output(modelinfo, linenumbernode)
427407 modeldef[:body ] = MacroTools. @q begin
428408 $ (linenumbernode)
429409 $ evaluator = $ (MacroTools. combinedef (evaluatordef))
410+ $ (allargs_decls... )
411+ $ (internals_decls... )
430412 return $ (DynamicPPL. Model)(
431413 $ (QuoteNode (modeldef[:name ])),
432414 $ evaluator,
433415 $ allargs_namedtuple,
434- $ defaults_namedtuple ,
416+ $ internals_namedtuple ,
435417 )
436418 end
437419
0 commit comments