Skip to content

Commit 2020741

Browse files
authored
Implement returned for AbstractDict; deprecate {values, keys} method (#1096)
* Implement returned for AbstractDict; deprecate {values, keys} method * Fix doctest * Add more tests (beyond the doctest) * Remove accs
1 parent 22740ed commit 2020741

File tree

5 files changed

+67
-28
lines changed

5 files changed

+67
-28
lines changed

HISTORY.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# DynamicPPL Changelog
22

3+
## 0.38.3
4+
5+
Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.
6+
Please note we generally recommend using Dict, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`.
7+
8+
The generic method `returned(::Model, values, keys)` is deprecated and will be removed in the next minor version.
9+
310
## 0.38.2
411

512
Added a compatibility entry for JET@0.11.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.38.2"
3+
version = "0.38.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,11 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
176176
@addlogprob!
177177
```
178178

179-
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.
179+
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples), or a single sample represented as a `NamedTuple` or a dictionary of VarNames.
180180

181181
```@docs
182182
returned(::DynamicPPL.Model, ::MCMCChains.Chains)
183-
returned(::DynamicPPL.Model, ::NamedTuple)
183+
returned(::DynamicPPL.Model, ::Union{NamedTuple,AbstractDict{<:VarName}})
184184
```
185185

186186
For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using

src/model.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,44 +1103,44 @@ function predict end
11031103

11041104
"""
11051105
returned(model::Model, parameters::NamedTuple)
1106-
returned(model::Model, values, keys)
1107-
returned(model::Model, values, keys)
1106+
returned(model::Model, parameters::AbstractDict{<:VarName})
11081107
11091108
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
11101109
1111-
If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameters)`.
1110+
returned(model::Model, values, keys)
1111+
1112+
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
1113+
This method is deprecated; use the NamedTuple or AbstractDict version instead.
11121114
11131115
# Example
11141116
```jldoctest
11151117
julia> using DynamicPPL, Distributions
11161118
1117-
julia> @model function demo(xs)
1118-
s ~ InverseGamma(2, 3)
1119-
m_shifted ~ Normal(10, √s)
1120-
m = m_shifted - 10
1121-
for i in eachindex(xs)
1122-
xs[i] ~ Normal(m, √s)
1123-
end
1124-
return (m, )
1119+
julia> @model function demo()
1120+
m ~ Normal()
1121+
return (mp1 = m + 1,)
11251122
end
11261123
demo (generic function with 2 methods)
11271124
1128-
julia> model = demo(randn(10));
1129-
1130-
julia> parameters = (; s = 1.0, m_shifted=10.0);
1125+
julia> model = demo();
11311126
1132-
julia> returned(model, parameters)
1133-
(0.0,)
1127+
julia> returned(model, (; m = 1.0))
1128+
(mp1 = 2.0,)
11341129
1135-
julia> returned(model, values(parameters), keys(parameters))
1136-
(0.0,)
1130+
julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
1131+
(mp1 = 3.0,)
11371132
```
11381133
"""
1139-
function returned(model::Model, parameters::NamedTuple)
1140-
fixed_model = fix(model, parameters)
1141-
return fixed_model()
1142-
end
1143-
1144-
function returned(model::Model, values, keys)
1145-
return returned(model, NamedTuple{keys}(values))
1134+
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
1135+
vi = DynamicPPL.setaccs!!(VarInfo(), ())
1136+
# Note: we can't use `fix(model, parameters)` because
1137+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
1138+
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
1139+
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
1140+
new_model = setleafcontext(model, ctx)
1141+
# We can't use new_model() because that overwrites it with an InitContext of its own.
1142+
return first(evaluate!!(new_model, vi))
11461143
end
1144+
Base.@deprecate returned(model::Model, values, keys) returned(
1145+
model, NamedTuple{keys}(values)
1146+
)

test/model.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,38 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
321321
end
322322
end
323323

324+
@testset "returned() on NamedTuple / Dict" begin
325+
@model function demo_returned()
326+
a ~ Normal()
327+
b ~ Normal()
328+
return (asq=a^2, bsq=b^2)
329+
end
330+
model = demo_returned()
331+
332+
@testset "NamedTuple" begin
333+
params = (a=1.0, b=2.0)
334+
results = returned(model, params)
335+
@test results.asq == params.a^2
336+
@test results.bsq == params.b^2
337+
# `returned` should error when not all parameters are provided
338+
@test_throws ErrorException returned(model, (; a=1.0))
339+
@test_throws ErrorException returned(model, (a=1.0, b=missing))
340+
end
341+
@testset "Dict" begin
342+
params = Dict{VarName,Float64}(@varname(a) => 1.0, @varname(b) => 2.0)
343+
results = returned(model, params)
344+
@test results.asq == params[@varname(a)]^2
345+
@test results.bsq == params[@varname(b)]^2
346+
# `returned` should error when not all parameters are provided
347+
@test_throws ErrorException returned(
348+
model, Dict{VarName,Float64}(@varname(a) => 1.0)
349+
)
350+
@test_throws ErrorException returned(
351+
model, Dict{VarName,Any}(@varname(a) => 1.0, @varname(b) => missing)
352+
)
353+
end
354+
end
355+
324356
@testset "returned() on `LKJCholesky`" begin
325357
n = 10
326358
d = 2

0 commit comments

Comments
 (0)