Skip to content

Commit 63f03d2

Browse files
committed
fix many things and add logjoint
1 parent 16fec6e commit 63f03d2

File tree

3 files changed

+98
-47
lines changed

3 files changed

+98
-47
lines changed

src/Turing.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,27 @@ include("stdlib/distributions.jl")
3737
include("stdlib/RandomMeasures.jl")
3838

3939
"""
40-
struct Model{F, Targs <: NamedTuple}
40+
struct Model{F, Targs <: NamedTuple, Tmissings <: Val}
4141
f::F
4242
args::Targs
43+
missings::Tmissings
4344
end
4445
4546
A `Model` struct with arguments `args` and inner function `f`.
4647
"""
47-
struct Model{F, Targs <: NamedTuple} <: AbstractModel
48+
struct Model{F, Targs <: NamedTuple, Tmissings <: Val} <: AbstractModel
4849
f::F
4950
args::Targs
51+
missings::Tmissings
5052
end
53+
Model(f, args::NamedTuple) = Model(f, args, getmissing(args))
5154
(model::Model)(vi) = model(vi, SampleFromPrior())
5255
(model::Model)(vi, spl) = model(vi, spl, DefaultContext())
5356
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
54-
getmissing(model::Model) = _getmissing(model.args)
55-
@generated function _getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}
57+
58+
getmissing(model::Model) = model.missings
59+
@generated function getmissing(args::NamedTuple{names, ttuple}) where {names, ttuple}
60+
length(names) == 0 && return :(Val{()}())
5661
minds = filter(1:length(names)) do i
5762
ttuple.types[i] == Missing
5863
end

src/core/prob_macro.jl

Lines changed: 82 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,39 +27,69 @@ end
2727

2828
function logprob(ex1, ex2)
2929
@assert isdefined(ex2, :model)
30-
ptype = probtype(ex2)
30+
modelgen = ex2.model
31+
ptype = probtype(ex1, ex2, modelgen, modelgen.defaults)
3132
vi = isdefined(ex2, :varinfo) ? ex2.varinfo : nothing
32-
model = ex2.model
3333
if ptype isa Val{:prior}
34-
return logprior(ex1, model, vi)
34+
return logprior(ex1, modelgen, vi)
3535
elseif ptype isa Val{:likelihood}
36-
return loglikelihood(ex1, ex2, model, vi)
36+
return loglikelihood(ex1, ex2, modelgen, vi)
3737
end
3838
end
3939

40-
@generated function probtype(
41-
ntr::NamedTuple{namesr},
42-
) where {namesr}
43-
if namesr == (:model,) || namesr == (:model, :varinfo)
44-
return :(Val(:prior))
40+
function probtype(
41+
ntl::NamedTuple{namesl},
42+
ntr::NamedTuple{namesr},
43+
modelgen::ModelGen{args},
44+
defaults::NamedTuple{defs},
45+
) where {namesl, namesr, args, defs}
46+
basic_namesr = namesr == (:model,) || namesr == (:model, :varinfo)
47+
@inline valid_arg(arg) = arg in namesl || arg in namesr || (arg in defs) &&
48+
!(getfield(defaults, arg) isa Missing)
49+
50+
valid_args = all(valid_arg.(args))
51+
# Uses the default values for model arguments not provided.
52+
# If no default value exists, use `nothing`.
53+
if basic_namesr
54+
return Val(:prior)
55+
# Uses the default values for model arguments not provided.
56+
# If no default value exists or the default value is missing, then error.
57+
elseif valid_args
58+
return Val(:likelihood)
4559
else
46-
return :(Val(:likelihood))
60+
for arg in args
61+
if !valid_arg(args)
62+
throw(ArgumentError(missing_arg_error_msg(arg)))
63+
end
64+
end
4765
end
4866
end
4967

68+
missing_arg_error_msg(arg) = """Variable $arg is not defined and has no default value, or its default value is `missing`. Please make sure all the variables are defined or have a default value other than `missing`."""
69+
5070
function logprior(
5171
left::NamedTuple,
5272
modelgen::ModelGen,
5373
_vi::Union{Nothing, VarInfo},
5474
)
55-
# Pass NaN to args which are not on the lhs of |
56-
# These will be ignored observe and dot_observe statements
57-
# Pass missing to args which are on the lhs of |
58-
# Their value is then assigned from left in PriorContext(left)
59-
# When all of args are on the lhs of |, this is also equal to the logjoint
60-
args = get_prior_model_args(left, modelgen, modelgen.defaults)
61-
model = modelgen(; args...)
62-
vi = _vi === nothing ? VarInfo(model, PriorContext()) : _vi
75+
# For model args on the LHS of |, use their passed value but add the symbol to
76+
# model.missings. This will lead to an `assume`/`dot_assume` call for those variables.
77+
# Let `p::PriorContext`. If `p.vars` is `nothing`, `assume` and `dot_assume` will use
78+
# the values of the random variables in the `VarInfo`. If `p.vars` is a `NamedTuple`
79+
# or a `Chain`, the value in `p.vars` is input into the `VarInfo` and used instead.
80+
81+
# For model args not on the LHS of |, if they have a default value, use that,
82+
# otherwise use `nothing`. This will lead to an `observe`/`dot_observe`call for
83+
# those variables.
84+
# All `observe` and `dot_observe` calls are no-op in the PriorContext
85+
86+
# When all of model args are on the lhs of |, this is also equal to the logjoint.
87+
args, missing_vars = get_prior_model_args(left, modelgen, modelgen.defaults)
88+
model = get_model(modelgen, args, missing_vars)
89+
vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi
90+
foreach(keys(vi.metadata)) do n
91+
@assert n in keys(left) "Variable $n is not defined."
92+
end
6393
model(vi, SampleFromPrior(), PriorContext(left))
6494
return vi.logp
6595
end
@@ -68,17 +98,26 @@ end
6898
modelgen::ModelGen{args},
6999
defaults::NamedTuple{default_args},
70100
) where {namesl, args, default_args}
71-
exprs = map(args) do arg
101+
exprs = []
102+
missing_args = []
103+
foreach(args) do arg
72104
if arg in namesl
73-
:($arg = missing)
74-
elseif arg in default_args
75-
:($arg = defaults.$arg)
105+
push!(exprs, :($arg = deepcopy(left.$arg)))
106+
push!(missing_args, arg)
107+
elseif arg in default_args
108+
push!(exprs, :($arg = defaults.$arg))
76109
else
77-
:($arg = NaN)
110+
push!(exprs, :($arg = nothing))
78111
end
79112
end
80-
length(exprs) == 0 && return :(NamedTuple())
81-
return :($(exprs...),)
113+
missing_vars = :(Val{($missing_args...,)}())
114+
length(exprs) == 0 && :(NamedTuple(), $missing_vars)
115+
return :(($(exprs...),), $missing_vars)
116+
end
117+
118+
function get_model(modelgen, args, missing_vars)
119+
_model = modelgen(; args...)
120+
return Turing.Model(_model.f, args, missing_vars)
82121
end
83122

84123
function loglikelihood(
@@ -88,9 +127,9 @@ function loglikelihood(
88127
_vi::Union{Nothing, VarInfo},
89128
)
90129
# Pass namesl to model constructor, remaining args are missing
91-
args = get_like_model_args(left, modelgen, modelgen.defaults)
92-
model = modelgen(; args...)
93-
vi = _vi === nothing ? VarInfo(model) : _vi
130+
args, missing_vars = get_like_model_args(left, right, modelgen, modelgen.defaults)
131+
model = get_model(modelgen, args, missing_vars)
132+
vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi
94133
if isdefined(right, :chain)
95134
# Element-wise likelihood for each value in chain
96135
ctx = LikelihoodContext()
@@ -110,20 +149,27 @@ function loglikelihood(
110149
end
111150
@generated function get_like_model_args(
112151
left::NamedTuple{namesl},
152+
right::NamedTuple{namesr},
113153
modelgen::ModelGen{args},
114154
defaults::NamedTuple{default_args},
115-
) where {namesl, args, default_args}
116-
exprs = map(args) do arg
155+
) where {namesl, namesr, args, default_args}
156+
exprs = []
157+
missing_args = []
158+
foreach(args) do arg
117159
if arg in namesl
118-
:($arg = left.$arg)
119-
elseif arg in default_args
120-
:($arg = defaults.$arg)
160+
push!(exprs, :($arg = left.$arg))
161+
elseif arg in namesr
162+
push!(exprs, :($arg = right.$arg))
163+
push!(missing_args, arg)
164+
elseif arg in default_args
165+
push!(exprs, :($arg = defaults.$arg))
121166
else
122-
:($arg = missing)
167+
throw("This point should not be reached. Please open an issue in the Turing.jl repository.")
123168
end
124169
end
125-
length(exprs) == 0 && return :(NamedTuple())
126-
return :($(exprs...),)
170+
missing_vars = :(Val{($missing_args...,)}())
171+
length(exprs) == 0 && :(NamedTuple(), $missing_vars)
172+
return :(($(exprs...),), $missing_vars)
127173
end
128174

129175
_setval!(vi::TypedVarInfo, c::Chains) = _setval!(vi.metadata, vi, c)

test/core/prob_macro.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
using Turing, Distributions, Test
22

3-
dir = splitdir(splitdir(pathof(Turing))[1])[1]
4-
include(dir*"/test/test_utils/AllUtils.jl")
5-
63
Random.seed!(129)
74

85
@turing_testset "logprob" begin
@@ -18,15 +15,17 @@ Random.seed!(129)
1815

1916
logprior = logpdf(Normal(), mval)
2017
loglike = logpdf(Normal(mval, 1), xval)
18+
logjoint = logprior + loglike
2119

2220
@test logprob"m = mval | model = demo" == logprior
2321
@test logprob"x = xval | m = mval, model = demo" == loglike
24-
#@test logprob"x = xval, m = mval | model = demo" == loglike
22+
@test logprob"x = xval, m = mval | model = demo" == logjoint
2523

2624
varinfo = Turing.VarInfo(demo(xval))
2725
@test logprob"m = mval | model = demo, varinfo = varinfo" == logprior
2826
@test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike
29-
#@test logprob"x = xval, m = mval | model = demo" == loglike
27+
varinfo = Turing.VarInfo(demo(missing))
28+
@test logprob"x = xval, m = mval | model = demo, varinfo = varinfo" == logjoint
3029

3130
chain = sample(demo(xval), IS(), iters)
3231
lps = logpdf.(Normal.(vec(chain["m"].value), 1), xval)
@@ -47,15 +46,16 @@ Random.seed!(129)
4746
logprior = sum(logpdf.(Normal(), mval))
4847
like(m, x) = sum(logpdf.(Normal.(m, 1), x))
4948
loglike = like(mval, xval)
49+
logjoint = logprior + loglike
5050

5151
@test logprob"m = mval | model = demo" == logprior
5252
@test logprob"x = xval | m = mval, model = demo" == loglike
53-
#@test logprob"x = xval, m = mval | model = demo" == logprior + loglike
53+
@test logprob"x = xval, m = mval | model = demo" == logjoint
5454

5555
varinfo = Turing.VarInfo(demo(xval))
5656
@test logprob"m = mval | model = demo, varinfo = varinfo" == logprior
5757
@test logprob"x = xval | m = mval, model = demo, varinfo = varinfo" == loglike
58-
#@test logprob"x = xval, m = mval | model = demo" == loglike
58+
# Currently, we cannot easily pre-allocate `VarInfo` for vector data
5959

6060
chain = sample(demo(xval), HMC(0.5, 1), iters)
6161
lps = like.([[chain["m[$i]"].value[j] for i in 1:n] for j in 1:iters], Ref(xval))

0 commit comments

Comments
 (0)