Skip to content
Merged
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -42,4 +44,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs"]
23 changes: 16 additions & 7 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
module DynamicPPL

using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions: UnivariateDistribution,
MultivariateDistribution,
MatrixDistribution,
Distribution
using Bijectors: link, invlink
using Distributions
using Bijectors
using MacroTools

import Base: string,
Expand Down Expand Up @@ -76,21 +73,33 @@ export VarName,
LikelihoodContext,
PriorContext,
MiniBatchContext,
assume,
dot_assume,
observer,
dot_observe,
tilde,
dot_tilde,
# Pseudo distributions
NamedDist,
NoDist,
# Prob macros
@prob_str,
@logprob_str

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_DYNAMICPPL", "0")))

# Used here and overloaded in Turing
function getspace end
function tilde end
function dot_tilde end

include("utils.jl")
include("selector.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")

Expand Down
125 changes: 8 additions & 117 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,116 +33,7 @@ function _error_msg()
return "This macro is only for use in the `@model` macro and not for external use."
end

"""
@varname(var)

A macro that returns an instance of `VarName` given the symbol or expression of a Julia variable, e.g. `@varname x[1,2][1+5][45][3]` returns `VarName{:x}("[1,2][6][45][3]")`.
"""
macro varname(expr::Union{Expr, Symbol})
expr |> varname |> esc
end
function varname(expr)
ex = deepcopy(expr)
(ex isa Symbol) && return quote
DynamicPPL.VarName{$(QuoteNode(ex))}("")
end
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
inds = :(())
while ex.head == :ref
if length(ex.args) >= 2
strs = map(x -> :($x === (:) ? "Colon()" : string($x)), ex.args[2:end])
pushfirst!(inds.args, :("[" * join($(Expr(:vect, strs...)), ",") * "]"))
end
ex = ex.args[1]
isa(ex, Symbol) && return quote
DynamicPPL.VarName{$(QuoteNode(ex))}(foldl(*, $inds, init = ""))
end
end
throw("VarName: Mis-formed variable name $(expr)!")
end

macro vsym(expr::Union{Expr, Symbol})
expr |> vsym
end

"""
vsym(expr::Union{Expr, Symbol})

Returns the variable symbol given the input variable expression `expr`. For example, if the input `expr = :(x[1])`, the output is `:x`.
"""
function vsym(expr::Union{Expr, Symbol})
ex = deepcopy(expr)
(ex isa Symbol) && return QuoteNode(ex)
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
while ex.head == :ref
ex = ex.args[1]
isa(ex, Symbol) && return QuoteNode(ex)
end
throw("VarName: Mis-formed variable name $(expr)!")
end

"""
@vinds(expr)

Returns a tuple of tuples of the indices in `expr`. For example, `@vinds x[1,:][2]` returns
`((1, Colon()), (2,))`.
"""
macro vinds(expr::Union{Expr, Symbol})
expr |> vinds |> esc
end
function vinds(expr::Union{Expr, Symbol})
ex = deepcopy(expr)
inds = Expr(:tuple)
(ex isa Symbol) && return inds
(ex.head == :ref) || throw("VarName: Mis-formed variable name $(expr)!")
while ex.head == :ref
pushfirst!(inds.args, Expr(:tuple, ex.args[2:end]...))
ex = ex.args[1]
isa(ex, Symbol) && return inds
end
throw("VarName: Mis-formed variable name $(expr)!")
end

"""
split_var_str(var_str, inds_as = Vector)

This function splits a variable string, e.g. `"x[1:3,1:2][3,2]"` to the variable's symbol `"x"` and the indexing `"[1:3,1:2][3,2]"`. If `inds_as = String`, the indices are returned as a string, e.g. `"[1:3,1:2][3,2]"`. If `inds_as = Vector`, the indices are returned as a vector of vectors of strings, e.g. `[["1:3", "1:2"], ["3", "2"]]`.
"""
function split_var_str(var_str, inds_as = Vector)
ind = findfirst(c -> c == '[', var_str)
if inds_as === String
if ind === nothing
return var_str, ""
else
return var_str[1:ind-1], var_str[ind:end]
end
end
@assert inds_as === Vector
inds = Vector{String}[]
if ind === nothing
return var_str, inds
end
sym = var_str[1:ind-1]
ind = length(sym)
while ind < length(var_str)
ind += 1
@assert var_str[ind] == '['
push!(inds, String[])
while var_str[ind] != ']'
ind += 1
if var_str[ind] == '['
ind2 = findnext(c -> c == ']', var_str, ind)
push!(inds[end], strip(var_str[ind:ind2]))
ind = ind2+1
else
ind2 = findnext(c -> c == ',' || c == ']', var_str, ind)
push!(inds[end], strip(var_str[ind:ind2-1]))
ind = ind2
end
end
end
return sym, inds
end

# Check if the right-hand side is a distribution.
function assert_dist(dist; msg)
Expand Down Expand Up @@ -404,21 +295,21 @@ function replace_tilde!(model_info)
ex = model_info[:main_body]
ex = MacroTools.postwalk(ex) do x
if @capture(x, @M_ L_ ~ R_) && M == Symbol("@__dot__")
dot_tilde(L, R, model_info)
generate_dot_tilde(L, R, model_info)
else
x
end
end
$(VERSION >= v"1.1" ? "ex = MacroTools.postwalk(ex) do x
if @capture(x, L_ .~ R_)
dot_tilde(L, R, model_info)
generate_dot_tilde(L, R, model_info)
else
x
end
end" : "")
ex = MacroTools.postwalk(ex) do x
if @capture(x, L_ ~ R_)
tilde(L, R, model_info)
generate_tilde(L, R, model_info)
else
x
end
Expand All @@ -429,12 +320,12 @@ end
""" |> Meta.parse |> eval

"""
tilde(left, right, model_info)
generate_tilde(left, right, model_info)

The `tilde` function generates `observe` expression for data variables and `assume`
expressions for parameter variables, updating `model_info` in the process.
"""
function tilde(left, right, model_info)
function generate_tilde(left, right, model_info)
arg_syms = Val((model_info[:arg_syms]...,))
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
Expand Down Expand Up @@ -478,11 +369,11 @@ function tilde(left, right, model_info)
end

"""
dot_tilde(left, right, model_info)
generate_dot_tilde(left, right, model_info)

This function returns the expression that replaces `left .~ right` in the model body. If `preprocessed isa VarName`, then a `dot_assume` block will be run. Otherwise, a `dot_observe` block will be run.
"""
function dot_tilde(left, right, model_info)
function generate_dot_tilde(left, right, model_info)
arg_syms = Val((model_info[:arg_syms]...,))
model = model_info[:main_body_names][:model]
vi = model_info[:main_body_names][:vi]
Expand Down Expand Up @@ -636,4 +527,4 @@ end
Get the specialized version of type `T` for sampler `spl`. For example,
if `T === Float64` and `spl::Hamiltonian`, the matching type is `eltype(vi[spl])`.
"""
function get_matching_type end
function get_matching_type end
Loading