Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ export AbstractVarInfo,
ModelGen,
@model,
@varname,
@varinfo,
@logpdf,
@sampler,
# Utilities
vectorize,
reconstruct,
Expand Down
171 changes: 66 additions & 105 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
macro varinfo()
:(throw(_error_msg()))
end
macro logpdf()
:(throw(_error_msg()))
end
macro sampler()
:(throw(_error_msg()))
end
function _error_msg()
return "This macro is only for use in the `@model` macro and not for external use."
end

const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
"Distributions."

const INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo)

"""
isassumption(model, expr)
isassumption(expr)

Return an expression that can be evaluated to check if `expr` is an assumption in the
`model`.
model.

Let `expr` be `:(x[1])`. It is an assumption in the following cases:
1. `x` is not among the input data to the `model`,
2. `x` is among the input data to the `model` but with a value `missing`, or
3. `x` is among the input data to the `model` with a value other than missing,
1. `x` is not among the input data to the model,
2. `x` is among the input data to the model but with a value `missing`, or
3. `x` is among the input data to the model with a value other than missing,
but `x[1] === missing`.

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(model, expr::Union{Symbol, Expr})
function isassumption(expr::Union{Symbol, Expr})
vn = gensym(:vn)

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, $model) || $(DynamicPPL.inmissings)($vn, $model)
if !$(DynamicPPL.inargnames)($vn, _model) || $(DynamicPPL.inmissings)($vn, _model)
true
else
# Evaluate the LHS
Expand All @@ -46,7 +35,7 @@ function isassumption(model, expr::Union{Symbol, Expr})
end

# failsafe: a literal is never an assumption
isassumption(model, expr) = :(false)
isassumption(expr) = :(false)

#################
# Main Compiler #
Expand Down Expand Up @@ -77,7 +66,7 @@ function model(expr)
modelinfo = build_model_info(expr)

# Generate main body
modelinfo[:main_body] = generate_mainbody(modelinfo)
modelinfo[:main_body] = generate_mainbody(modelinfo[:main_body], modelinfo[:args])

return build_output(modelinfo)
end
Expand Down Expand Up @@ -166,67 +155,57 @@ function build_model_info(input_expr)
:args_nt => args_nt,
:defaults_nt => defaults_nt,
:args => args,
:whereparams => modeldef[:whereparams],
:main_body_names => Dict(
:ctx => gensym(:ctx),
:vi => gensym(:vi),
:sampler => gensym(:sampler),
:model => gensym(:model)
)
:whereparams => modeldef[:whereparams]
)

return model_info
end

"""
generate_mainbody([expr, ]modelinfo)
generate_mainbody(expr, args)

Generate the body of the main evaluation function.
Generate the body of the main evaluation function from expression `expr` and arguments
`args`.
"""
generate_mainbody(modelinfo) = generate_mainbody(modelinfo[:main_body], modelinfo)
generate_mainbody(expr, args) = generate_mainbody!(Symbol[], expr, args)

generate_mainbody(x, modelinfo) = x
function generate_mainbody(expr::Expr, modelinfo)
generate_mainbody!(found, x, args) = x
function generate_mainbody!(found, sym::Symbol, args)
if sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$(sym)`"
push!(found, sym)
end
return sym
end
function generate_mainbody!(found, expr::Expr, args)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Apply the `@.` macro first.
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
expr.args[1] === Symbol("@__dot__")
return generate_mainbody(Base.Broadcast.__dot__(expr.args[end]), modelinfo)
end

# Modify macro calls.
if Meta.isexpr(expr, :macrocall) && !isempty(expr.args)
name = expr.args[1]
if name === Symbol("@varinfo")
return modelinfo[:main_body_names][:vi]
elseif name === Symbol("@logpdf")
return :($(modelinfo[:main_body_names][:vi]).logp[])
elseif name === Symbol("@sampler")
return :($(modelinfo[:main_body_names][:sampler]))
end
return generate_mainbody!(found, Base.Broadcast.__dot__(expr.args[end]), args)
end

# Modify dotted tilde operators.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody(L, modelinfo),
generate_mainbody(R, modelinfo),
modelinfo))
return Base.remove_linenums!(generate_dot_tilde(generate_mainbody!(found, L, args),
generate_mainbody!(found, R, args),
args))
end

# Modify tilde operators.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return Base.remove_linenums!(generate_tilde(generate_mainbody(L, modelinfo),
generate_mainbody(R, modelinfo),
modelinfo))
return Base.remove_linenums!(generate_tilde(generate_mainbody!(found, L, args),
generate_mainbody!(found, R, args),
args))
end

return Expr(expr.head, map(x -> generate_mainbody(x, modelinfo), expr.args)...)
return Expr(expr.head, map(x -> generate_mainbody!(found, x, args), expr.args)...)
end

"""
Expand Down Expand Up @@ -268,17 +247,12 @@ end


"""
generate_tilde(left, right, model_info)
generate_tilde(left, right, args)

The `tilde` function generates `observe` expression for data variables and `assume`
expressions for parameter variables, updating `model_info` in the process.
Generate an `observe` expression for data variables and `assume` expression for parameter
variables.
"""
function generate_tilde(left, right, model_info)
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

function generate_tilde(left, right, args)
@gensym tmpright tmpleft
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
Expand All @@ -289,26 +263,26 @@ function generate_tilde(left, right, model_info)
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

assumption = [
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
$vi)),
:($(DynamicPPL.acclogp!)($vi, $out[2])),
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
_varinfo)),
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
:($left = $out[1])
]

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in model_info[:args]
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(model, left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$(assumption...)
else
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vn,
$inds, $vi)
_varinfo,
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
$vn, $inds, _varinfo)
)
$tmpleft
end
Expand All @@ -326,26 +300,19 @@ function generate_tilde(left, right, model_info)
$(top...)
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
_varinfo,
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $tmpleft, _varinfo)
)
$tmpleft
end
end

"""
generate_dot_tilde(left, right, model_info)
generate_dot_tilde(left, right, args)

This function returns the expression that replaces `left .~ right` in the model body. If
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
will be run.
Generate the expression that replaces `left .~ right` in the model body.
"""
function generate_dot_tilde(left, right, model_info)
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

function generate_dot_tilde(left, right, args)
@gensym tmpright tmpleft
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
Expand All @@ -356,26 +323,26 @@ function generate_dot_tilde(left, right, model_info)
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

assumption = [
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
$vn, $inds, $vi)),
:($(DynamicPPL.acclogp!)($vi, $out[2])),
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
$vn, $inds, _varinfo)),
:($(DynamicPPL.acclogp!)(_varinfo, $out[2])),
:($left .= $out[1])
]

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in model_info[:args]
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(model, left))
$isassumption = $(DynamicPPL.isassumption(left))
if $isassumption
$(assumption...)
else
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft,
$vn, $inds, $vi)
_varinfo,
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright,
$tmpleft, $vn, $inds, _varinfo)
)
$tmpleft
end
Expand All @@ -393,8 +360,9 @@ function generate_dot_tilde(left, right, model_info)
$(top...)
$tmpleft = $left
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $tmpleft, $vi)
_varinfo,
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $tmpleft,
_varinfo)
)
$tmpleft
end
Expand All @@ -411,13 +379,6 @@ hasmissing(T::Type) = false
Builds the output expression.
"""
function build_output(model_info)
# Construct user-facing function
main_body_names = model_info[:main_body_names]
ctx = main_body_names[:ctx]
vi = main_body_names[:vi]
model = main_body_names[:model]
sampler = main_body_names[:sampler]

# Arguments with default values
args = model_info[:args]
# Argument symbols without default values
Expand All @@ -437,7 +398,7 @@ function build_output(model_info)
unwrap_data_expr = Expr(:block)
for var in arg_syms
push!(unwrap_data_expr.args,
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end

@gensym(evaluator, generator)
Expand All @@ -446,10 +407,10 @@ function build_output(model_info)

return quote
function $evaluator(
$model::$(DynamicPPL.Model),
$vi::$(DynamicPPL.VarInfo),
$sampler::$(DynamicPPL.AbstractSampler),
$ctx::$(DynamicPPL.AbstractContext),
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.VarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$main_body
Expand Down
3 changes: 0 additions & 3 deletions test/Turing/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ end
# Turing essentials - modelling macros and inference algorithms
export @model, # modelling
@varname,
@varinfo,
@logpdf,
@sampler,
DynamicPPL,

MH, # classic sampling
Expand Down
3 changes: 0 additions & 3 deletions test/Turing/core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ export @model,
ADBACKEND,
setchunksize,
verifygrad,
@varinfo,
@logpdf,
@sampler,
@logprob_str,
@prob_str

Expand Down
Loading