Skip to content

Commit 6de3d25

Browse files
authored
Try #268:
2 parents 222091e + 8c1a1f4 commit 6de3d25

File tree

8 files changed

+272
-240
lines changed

8 files changed

+272
-240
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ export AbstractVarInfo,
6767
vectorize,
6868
# Model
6969
Model,
70-
getmissings,
71-
getargnames,
70+
getargumentnames,
7271
generated_quantities,
7372
# Samplers
7473
Sampler,

src/compiler.jl

Lines changed: 85 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,6 @@
11
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
22
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)
33

4-
"""
5-
isassumption(expr)
6-
7-
Return an expression that can be evaluated to check if `expr` is an assumption in the
8-
model.
9-
10-
Let `expr` be `:(x[1])`. It is an assumption in the following cases:
11-
1. `x` is not among the input data to the model,
12-
2. `x` is among the input data to the model but with a value `missing`, or
13-
3. `x` is among the input data to the model with a value other than missing,
14-
but `x[1] === missing`.
15-
16-
When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
17-
"""
18-
function isassumption(expr::Union{Symbol,Expr})
19-
vn = gensym(:vn)
20-
21-
return quote
22-
let $vn = $(varname(expr))
23-
# This branch should compile nicely in all cases except for partial missing data
24-
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
25-
if !$(DynamicPPL.inargnames)($vn, __model__) ||
26-
$(DynamicPPL.inmissings)($vn, __model__)
27-
true
28-
else
29-
# Evaluate the LHS
30-
$expr === missing
31-
end
32-
end
33-
end
34-
end
35-
36-
# failsafe: a literal is never an assumption
37-
isassumption(expr) = :(false)
384

395
"""
406
isliteral(expr)
@@ -137,8 +103,14 @@ end
137103
function model(mod, linenumbernode, expr, warn)
138104
modelinfo = build_model_info(expr)
139105

140-
# Generate main body
141-
modelinfo[:body] = generate_mainbody(mod, modelinfo[:modeldef][:body], warn)
106+
# Generate main body and find all variable symbols
107+
modelinfo[:body], modelinfo[:varnames] = generate_mainbody(
108+
mod, modelinfo[:modeldef][:body], warn
109+
)
110+
111+
# extract observations from that
112+
modelinfo[:obsnames] = modelinfo[:allargs_syms] modelinfo[:varnames]
113+
modelinfo[:latentnames] = setdiff(modelinfo[:varnames], modelinfo[:allargs_syms])
142114

143115
return build_output(modelinfo, linenumbernode)
144116
end
@@ -167,8 +139,7 @@ function build_model_info(input_expr)
167139
modelinfo = Dict(
168140
:allargs_exprs => [],
169141
:allargs_syms => [],
170-
:allargs_namedtuple => NamedTuple(),
171-
:defaults_namedtuple => NamedTuple(),
142+
:allargs_defaults => [],
172143
:modeldef => modeldef,
173144
)
174145
return modelinfo
@@ -177,17 +148,18 @@ function build_model_info(input_expr)
177148
# Extract the positional and keyword arguments from the model definition.
178149
allargs = vcat(modeldef[:args], modeldef[:kwargs])
179150

180-
# Split the argument expressions and the default values.
181-
allargs_exprs_defaults = map(allargs) do arg
182-
MacroTools.@match arg begin
151+
# Split the argument expressions and the default values, by unzipping allargs, taking care of
152+
# the empty case
153+
allargs_exprs, allargs_defaults = foldl(allargs; init=([], [])) do (ae, ad), arg
154+
(expr, default) = MacroTools.@match arg begin
183155
(x_ = val_) => (x, val)
184156
x_ => (x, NO_DEFAULT)
185157
end
158+
push!(ae, expr)
159+
push!(ad, default)
160+
ae, ad
186161
end
187-
188-
# Extract the expressions of the arguments, without default values.
189-
allargs_exprs = first.(allargs_exprs_defaults)
190-
162+
191163
# Extract the names of the arguments.
192164
allargs_syms = map(allargs_exprs) do arg
193165
MacroTools.@match arg begin
@@ -196,28 +168,11 @@ function build_model_info(input_expr)
196168
x_ => x
197169
end
198170
end
199-
200-
# Build named tuple expression of the argument symbols and variables of the same name.
201-
allargs_namedtuple = to_namedtuple_expr(allargs_syms)
202-
203-
# Extract default values of the positional and keyword arguments.
204-
default_syms = []
205-
default_vals = []
206-
for (sym, (expr, val)) in zip(allargs_syms, allargs_exprs_defaults)
207-
if val !== NO_DEFAULT
208-
push!(default_syms, sym)
209-
push!(default_vals, val)
210-
end
211-
end
212-
213-
# Build named tuple expression of the argument symbols with default values.
214-
defaults_namedtuple = to_namedtuple_expr(default_syms, default_vals)
215-
171+
216172
modelinfo = Dict(
217173
:allargs_exprs => allargs_exprs,
218174
:allargs_syms => allargs_syms,
219-
:allargs_namedtuple => allargs_namedtuple,
220-
:defaults_namedtuple => defaults_namedtuple,
175+
:allargs_defaults => allargs_defaults,
221176
:modeldef => modeldef,
222177
)
223178

@@ -233,43 +188,50 @@ Generate the body of the main evaluation function from expression `expr` and arg
233188
If `warn` is true, a warning is displayed if internal variables are used in the model
234189
definition.
235190
"""
236-
generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, warn)
191+
function generate_mainbody(mod, expr, warn)
192+
varnames = Symbol[]
193+
body = generate_mainbody!(mod, Symbol[], varnames, expr, warn)
194+
return body, varnames
195+
end
237196

238-
generate_mainbody!(mod, found, x, warn) = x
239-
function generate_mainbody!(mod, found, sym::Symbol, warn)
197+
generate_mainbody!(mod, found_internals, varnames, x, warn) = x
198+
function generate_mainbody!(mod, found_internals, sym::Symbol, warn)
240199
if sym in DEPRECATED_INTERNALNAMES
241200
newsym = Symbol(:_, sym, :__)
242201
Base.depwarn(
243202
"internal variable `$sym` is deprecated, use `$newsym` instead.",
244203
:generate_mainbody!,
245204
)
246-
return generate_mainbody!(mod, found, newsym, warn)
205+
return generate_mainbody!(mod, found_internals, newsym, warn)
247206
end
248207

249-
if warn && sym in INTERNALNAMES && sym found
208+
if warn && sym in INTERNALNAMES && sym found_internals
250209
@warn "you are using the internal variable `$sym`"
251-
push!(found, sym)
210+
push!(found_internals, sym)
252211
end
253212

254213
return sym
255214
end
256-
function generate_mainbody!(mod, found, expr::Expr, warn)
215+
function generate_mainbody!(mod, found_internals, varnames, expr::Expr, warn)
257216
# Do not touch interpolated expressions
258217
expr.head === :$ && return expr.args[1]
259218

260219
# If it's a macro, we expand it
261220
if Meta.isexpr(expr, :macrocall)
262-
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
221+
return generate_mainbody!(
222+
mod, found_internals, varnames, macroexpand(mod, expr; recursive=true), warn
223+
)
263224
end
264225

265226
# Modify dotted tilde operators.
266227
args_dottilde = getargs_dottilde(expr)
267228
if args_dottilde !== nothing
268229
L, R = args_dottilde
230+
!isliteral(L) && push!(varnames, vsym(L))
269231
return Base.remove_linenums!(
270232
generate_dot_tilde(
271-
generate_mainbody!(mod, found, L, warn),
272-
generate_mainbody!(mod, found, R, warn),
233+
generate_mainbody!(mod, found_internals, varnames, L, warn),
234+
generate_mainbody!(mod, found_internals, varnames, R, warn),
273235
),
274236
)
275237
end
@@ -278,15 +240,19 @@ function generate_mainbody!(mod, found, expr::Expr, warn)
278240
args_tilde = getargs_tilde(expr)
279241
if args_tilde !== nothing
280242
L, R = args_tilde
243+
!isliteral(L) && push!(varnames, vsym(L))
281244
return Base.remove_linenums!(
282245
generate_tilde(
283-
generate_mainbody!(mod, found, L, warn),
284-
generate_mainbody!(mod, found, R, warn),
246+
generate_mainbody!(mod, found_internals, varnames, L, warn),
247+
generate_mainbody!(mod, found_internals, varnames, R, warn),
285248
),
286249
)
287250
end
288251

289-
return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...)
252+
return Expr(
253+
expr.head,
254+
map(x -> generate_mainbody!(mod, found_internals, varnames, x, warn), expr.args)...,
255+
)
290256
end
291257

292258
"""
@@ -307,26 +273,26 @@ function generate_tilde(left, right)
307273

308274
# Otherwise it is determined by the model or its value,
309275
# if the LHS represents an observation
310-
@gensym vn inds isassumption
276+
@gensym vn inds isobservation
311277
return quote
312278
$vn = $(varname(left))
313279
$inds = $(vinds(left))
314-
$isassumption = $(DynamicPPL.isassumption(left))
315-
if $isassumption
316-
$left = $(DynamicPPL.tilde_assume!)(
280+
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
281+
if $isobservation
282+
$(DynamicPPL.tilde_observe!)(
317283
__context__,
318-
$(DynamicPPL.unwrap_right_vn)(
319-
$(DynamicPPL.check_tilde_rhs)($right), $vn
320-
)...,
284+
$(DynamicPPL.check_tilde_rhs)($right),
285+
$left,
286+
$vn,
321287
$inds,
322288
__varinfo__,
323289
)
324290
else
325-
$(DynamicPPL.tilde_observe!)(
291+
$left = $(DynamicPPL.tilde_assume!)(
326292
__context__,
327-
$(DynamicPPL.check_tilde_rhs)($right),
328-
$left,
329-
$vn,
293+
$(DynamicPPL.unwrap_right_vn)(
294+
$(DynamicPPL.check_tilde_rhs)($right), $vn
295+
)...,
330296
$inds,
331297
__varinfo__,
332298
)
@@ -351,26 +317,26 @@ function generate_dot_tilde(left, right)
351317

352318
# Otherwise it is determined by the model or its value,
353319
# if the LHS represents an observation
354-
@gensym vn inds isassumption
320+
@gensym vn inds isobservation
355321
return quote
356322
$vn = $(varname(left))
357323
$inds = $(vinds(left))
358-
$isassumption = $(DynamicPPL.isassumption(left))
359-
if $isassumption
360-
$left .= $(DynamicPPL.dot_tilde_assume!)(
324+
$isobservation = $(DynamicPPL.isobservation)($vn, __model__)
325+
if $isobservation
326+
$(DynamicPPL.dot_tilde_observe!)(
361327
__context__,
362-
$(DynamicPPL.unwrap_right_left_vns)(
363-
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
364-
)...,
328+
$(DynamicPPL.check_tilde_rhs)($right),
329+
$left,
330+
$vn,
365331
$inds,
366332
__varinfo__,
367333
)
368334
else
369-
$(DynamicPPL.dot_tilde_observe!)(
335+
$left .= $(DynamicPPL.dot_tilde_assume!)(
370336
__context__,
371-
$(DynamicPPL.check_tilde_rhs)($right),
372-
$left,
373-
$vn,
337+
$(DynamicPPL.unwrap_right_left_vns)(
338+
$(DynamicPPL.check_tilde_rhs)($right), $left, $vn
339+
)...,
374340
$inds,
375341
__varinfo__,
376342
)
@@ -413,10 +379,24 @@ function build_output(modelinfo, linenumbernode)
413379

414380
## Build the model function.
415381

416-
# Extract the named tuple expression of all arguments and the default values.
417-
allargs_namedtuple = modelinfo[:allargs_namedtuple]
418-
defaults_namedtuple = modelinfo[:defaults_namedtuple]
382+
# Extract the named tuple expression of all arguments
383+
allargs_newnames = [gensym(x) for x in modelinfo[:allargs_syms]]
384+
allargs_wrapped = map(modelinfo[:allargs_syms], modelinfo[:allargs_defaults]) do x, d
385+
if x modelinfo[:obsnames]
386+
:($(DynamicPPL.Variable)($x, $d))
387+
else
388+
:($(DynamicPPL.Constant)($x, $d))
389+
end
390+
end
391+
allargs_decls = [:($name = $val) for (name, val) in zip(allargs_newnames, allargs_wrapped)]
392+
allargs_namedtuple = to_namedtuple_expr(modelinfo[:allargs_syms], allargs_newnames)
419393

394+
internals_newnames = [gensym(x) for x in modelinfo[:latentnames]]
395+
internals_decls = map(internals_newnames) do name
396+
:($name = $(DynamicPPL.Variable)(missing))
397+
end
398+
internals_namedtuple = to_namedtuple_expr(modelinfo[:latentnames], internals_newnames)
399+
420400
# Update the function body of the user-specified model.
421401
# We use a name for the anonymous evaluator that does not conflict with other variables.
422402
modeldef = modelinfo[:modeldef]
@@ -427,11 +407,13 @@ function build_output(modelinfo, linenumbernode)
427407
modeldef[:body] = MacroTools.@q begin
428408
$(linenumbernode)
429409
$evaluator = $(MacroTools.combinedef(evaluatordef))
410+
$(allargs_decls...)
411+
$(internals_decls...)
430412
return $(DynamicPPL.Model)(
431413
$(QuoteNode(modeldef[:name])),
432414
$evaluator,
433415
$allargs_namedtuple,
434-
$defaults_namedtuple,
416+
$internals_namedtuple,
435417
)
436418
end
437419

src/context_implementations.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
1414
require_gradient(spl::Sampler) = false
1515
require_particles(spl::Sampler) = false
1616

17-
_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds))
18-
_getindex(x, inds::Tuple{}) = x
1917

2018
# assume
2119
"""

0 commit comments

Comments
 (0)