Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 88 additions & 138 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ end
const DISTMSG = "Right-hand side of a ~ must be subtype of Distribution or a vector of " *
"Distributions."

const RESERVEDNAMES = (:_model, :_sampler, :_context, :_varinfo)

"""
isassumption(model, expr)

Expand Down Expand Up @@ -69,9 +71,57 @@ end

To generate a `Model`, call `model_generator(x_value)`.
"""
macro model(input_expr)
build_model_info(input_expr) |> replace_tilde! |> replace_vi! |>
replace_logpdf! |> replace_sampler! |> build_output
macro model(expr)
return esc(model(expr))
end

function model(expr)
modelinfo = build_model_info(expr)

ex = generate_main_body(modelinfo[:main_body], modelinfo[:args])
modelinfo[:main_body] = ex

return build_output(modelinfo)
end

generate_main_body(x, args) = generate_main_body(x, args, Symbol[])
generate_main_body(x, args, checked) = x
function generate_main_body(sym::Symbol, args, checked)
if sym in RESERVEDNAMES && sym ∉ checked
@warn "you are using the reserved name `$(sym)`"
push!(checked, sym)
end
return sym
end
function generate_main_body(expr::Expr, args, checked)
# do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Apply the `@.` macro first.
if Meta.isexpr(expr, :macrocall) && length(expr.args) > 1 &&
expr.args[1] === Symbol("@__dot__")
return generate_main_body(Base.Broadcast.__dot__(expr.args[end]), args, checked)
end

# Check dot tilde.
args_dottilde = getargs_dottilde(expr)
if args_dottilde !== nothing
L, R = args_dottilde
return Base.remove_linenums!(generate_dot_tilde(generate_main_body(L, args, checked),
generate_main_body(R, args, checked),
args))
end

# Check tilde.
args_tilde = getargs_tilde(expr)
if args_tilde !== nothing
L, R = args_tilde
return Base.remove_linenums!(generate_tilde(generate_main_body(L, args, checked),
generate_main_body(R, args, checked),
args))
end

return Expr(expr.head, map(x -> generate_main_body(x, args, checked), expr.args)...)
end

"""
Expand Down Expand Up @@ -158,76 +208,12 @@ function build_model_info(input_expr)
:args_nt => args_nt,
:defaults_nt => defaults_nt,
:args => args,
:whereparams => modeldef[:whereparams],
:main_body_names => Dict(
:ctx => gensym(:ctx),
:vi => gensym(:vi),
:sampler => gensym(:sampler),
:model => gensym(:model)
)
:whereparams => modeldef[:whereparams]
)

return model_info
end


"""
replace_vi!(model_info)

Replaces `@varinfo()` expressions with a handle to the `VarInfo` struct.
"""
function replace_vi!(model_info)
ex = model_info[:main_body]
vi = model_info[:main_body_names][:vi]
ex = MacroTools.postwalk(ex) do x
if @capture(x, @varinfo())
vi
else
x
end
end
model_info[:main_body] = ex
return model_info
end

"""
replace_logpdf!(model_info)

Replaces `@logpdf()` expressions with the value of the accumulated `logpdf` in the `VarInfo` struct.
"""
function replace_logpdf!(model_info)
ex = model_info[:main_body]
vi = model_info[:main_body_names][:vi]
ex = MacroTools.postwalk(ex) do x
if @capture(x, @logpdf())
:($(vi).logp[])
else
x
end
end
model_info[:main_body] = ex
return model_info
end

"""
replace_sampler!(model_info)

Replaces `@sampler()` expressions with a handle to the sampler struct.
"""
function replace_sampler!(model_info)
ex = model_info[:main_body]
spl = model_info[:main_body_names][:sampler]
ex = MacroTools.postwalk(ex) do x
if @capture(x, @sampler())
spl
else
x
end
end
model_info[:main_body] = ex
return model_info
end

"""
replace_tilde!(model_info)

Expand All @@ -238,24 +224,6 @@ function replace_tilde!(model_info)
expr = model_info[:main_body]
dottedexpr = MacroTools.postwalk(apply_dotted, expr)

# Check for tilde operators.
tildeexpr = MacroTools.postwalk(dottedexpr) do x
# Check dot tilde first.
dotargs = getargs_dottilde(x)
if dotargs !== nothing
L, R = dotargs
return Base.remove_linenums!(generate_dot_tilde(L, R, model_info))
end

# Check tilde.
args = getargs_tilde(x)
if args !== nothing
L, R = args
return Base.remove_linenums!(generate_tilde(L, R, model_info))
end

return x
end

# Update the function body.
model_info[:main_body] = tildeexpr
Expand All @@ -267,17 +235,12 @@ end


"""
generate_tilde(left, right, model_info)
generate_tilde(left, right, args)

The `tilde` function generates `observe` expression for data variables and `assume`
expressions for parameter variables, updating `model_info` in the process.
Generate `observe` expressions for data variables and `assume` expressions for parameter
variables for a model with the given `args`.
"""
function generate_tilde(left, right, model_info)
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

function generate_tilde(left, right, args)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
Expand All @@ -288,25 +251,25 @@ function generate_tilde(left, right, model_info)
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

assumption = [
:($out = $(DynamicPPL.tilde_assume)($ctx, $sampler, $tmpright, $vn, $inds,
$vi)),
:($out = $(DynamicPPL.tilde_assume)(_context, _sampler, $tmpright, $vn, $inds,
_varinfo)),
:($left = $out[1]),
:($(DynamicPPL.acclogp!)($vi, $out[2]))
:($(DynamicPPL.acclogp!)(_varinfo, $out[2]))
]

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in model_info[:args]
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(model, left))
$isassumption = $(DynamicPPL.isassumption(:_model, left))
if $isassumption
$(assumption...)
else
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vn,
$inds, $vi)
_varinfo,
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, $vn,
$inds, _varinfo)
)
end
end
Expand All @@ -322,25 +285,19 @@ function generate_tilde(left, right, model_info)
return quote
$(top...)
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
_varinfo,
$(DynamicPPL.tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
)
end
end

"""
generate_dot_tilde(left, right, model_info)
generate_dot_tilde(left, right, args)

This function returns the expression that replaces `left .~ right` in the model body. If
`preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block
will be run.
Generate broadcasted `observe` expressions for data variables and `assume` expressions for parameter
variables for a model with the given `args`.
"""
function generate_dot_tilde(left, right, model_info)
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
ctx = model_info[:main_body_names][:ctx]
sampler = model_info[:main_body_names][:sampler]

function generate_dot_tilde(left, right, args)
@gensym tmpright
top = [:($tmpright = $right),
:($tmpright isa Union{$Distribution,AbstractVector{<:$Distribution}}
Expand All @@ -351,25 +308,25 @@ function generate_dot_tilde(left, right, model_info)
push!(top, :($vn = $(varname(left))), :($inds = $(vinds(left))))

assumption = [
:($out = $(DynamicPPL.dot_tilde_assume)($ctx, $sampler, $tmpright, $left,
$vn, $inds, $vi)),
:($out = $(DynamicPPL.dot_tilde_assume)(_context, _sampler, $tmpright, $left,
$vn, $inds, _varinfo)),
:($left .= $out[1]),
:($(DynamicPPL.acclogp!)($vi, $out[2]))
:($(DynamicPPL.acclogp!)(_varinfo, $out[2]))
]

# It can only be an observation if the LHS is an argument of the model
if vsym(left) in model_info[:args]
if vsym(left) in args
@gensym isassumption
return quote
$(top...)
$isassumption = $(DynamicPPL.isassumption(model, left))
$isassumption = $(DynamicPPL.isassumption(:_model, left))
if $isassumption
$(assumption...)
else
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left,
$vn, $inds, $vi)
_varinfo,
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left,
$vn, $inds, _varinfo)
)
end
end
Expand All @@ -385,8 +342,8 @@ function generate_dot_tilde(left, right, model_info)
return quote
$(top...)
$(DynamicPPL.acclogp!)(
$vi,
$(DynamicPPL.dot_tilde_observe)($ctx, $sampler, $tmpright, $left, $vi)
_varinfo,
$(DynamicPPL.dot_tilde_observe)(_context, _sampler, $tmpright, $left, _varinfo)
)
end
end
Expand All @@ -402,13 +359,6 @@ hasmissing(T::Type) = false
Builds the output expression.
"""
function build_output(model_info)
# Construct user-facing function
main_body_names = model_info[:main_body_names]
ctx = main_body_names[:ctx]
vi = main_body_names[:vi]
model = main_body_names[:model]
sampler = main_body_names[:sampler]

# Arguments with default values
args = model_info[:args]
# Argument symbols without default values
Expand All @@ -428,7 +378,7 @@ function build_output(model_info)
unwrap_data_expr = Expr(:block)
for var in arg_syms
push!(unwrap_data_expr.args,
:($var = $(DynamicPPL.matchingvalue)($sampler, $vi, $(model).args.$var)))
:($var = $(DynamicPPL.matchingvalue)(_sampler, _varinfo, _model.args.$var)))
end

@gensym(evaluator, generator)
Expand All @@ -437,13 +387,13 @@ function build_output(model_info)

ex = quote
function $evaluator(
$model::$(DynamicPPL.Model),
$vi::$(DynamicPPL.VarInfo),
$sampler::$(DynamicPPL.AbstractSampler),
$ctx::$(DynamicPPL.AbstractContext),
_model::$(DynamicPPL.Model),
_varinfo::$(DynamicPPL.VarInfo),
_sampler::$(DynamicPPL.AbstractSampler),
_context::$(DynamicPPL.AbstractContext),
)
$unwrap_data_expr
$(DynamicPPL.resetlogp!)($vi)
$(DynamicPPL.resetlogp!)(_varinfo)
$main_body
end

Expand Down