Skip to content
126 changes: 61 additions & 65 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

for name in INTERNALNAMES
@eval const $(Symbol(uppercase(string(name)))) = $(Meta.quot(name))
end


"""
isassumption(expr)
isassumption(expr, vn)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -14,38 +19,37 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Symbol,Expr}, vn)
return quote
let $vn = $(AbstractPPL.drop_escape(varname(expr)))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)($__CONTEXT__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, $__MODEL__)) ||
$(DynamicPPL.inmissings)($vn, $__MODEL__)
true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr, vn) = :(false)

"""
contextual_isassumption(context, vn)

Expand Down Expand Up @@ -79,9 +83,6 @@ function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -314,15 +315,13 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Do we don't want escaped expressions because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
if Meta.isexpr(expr, :$)
# Do not touch interpolated expressions
return expr.args[1]
elseif Meta.isexpr(expr, :escape)
return generate_mainbody(mod, found, expr.args[1], warn)
elseif Meta.isexpr(expr, :macrocall)
# If it's a macro, we expand it (recursively)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
end

Expand Down Expand Up @@ -357,7 +356,7 @@ function generate_tilde_literal(left, right)
# If the LHS is a literal, it is always an observation
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
)
end
end
Expand All @@ -375,26 +374,23 @@ function generate_tilde(left, right)
# if the LHS represents an observation
@gensym vn isassumption

# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)($vn, $__MODEL__)
$left = $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -403,14 +399,14 @@ end
function generate_tilde_assume(left, right, vn)
expr = :(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
$__VARINFO__,
)
)

return if left isa Expr
AbstractPPL.drop_escape(
if left isa Expr
return AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
else
Expand All @@ -431,21 +427,21 @@ function generate_dot_tilde(left, right)
@gensym vn isassumption
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)($vn, $__MODEL__)
$left .= $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -457,11 +453,11 @@ function generate_dot_tilde_assume(left, right, vn)
# be something that supports `.=`.
return :(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
$__VARINFO__,
)
)
end
Expand All @@ -479,15 +475,14 @@ Builds the output expression.
function build_output(modelinfo, linenumbernode)
## Build the anonymous evaluator from the user-provided model definition.
evaluatordef = deepcopy(modelinfo[:modeldef])
original_arguments = modelinfo[:allargs_exprs]

# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
original_arguments,
)

# Delete the keyword arguments.
Expand All @@ -513,10 +508,11 @@ function build_output(modelinfo, linenumbernode)
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modelname_symbol = Meta.quot(modeldef[:name])
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$modelname_symbol,
$(modeldef[:name]),
$allargs_namedtuple,
$defaults_namedtuple,
Expand Down