Skip to content

Commit 9bd8f16

Browse files
authored
Change pointwise_logdensities default key type to VarName (#1071)
* Change pointwise_logdensities default key type to VarName * Fix a doctest
1 parent 01bf0bc commit 9bd8f16

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t
6161
The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead.
6262
`loadstate` is exported from DynamicPPL.
6363

64+
### Change of default keytype of `pointwise_logdensities`
65+
66+
The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`.
67+
6468
**Other changes**
6569

6670
### `predict(model, chain; include_all)`

src/pointwise_logdensities.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,13 @@ end
116116
::Val{whichlogprob}=Val(:both),
117117
)
118118
119-
Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}`
120-
with keys corresponding to symbols of the variables, and values being matrices
121-
of shape `(num_chains, num_samples)`.
119+
Runs `model` on each sample in `chain` returning a `OrderedDict{VarName, Matrix{Float64}}`
120+
with keys being model variables and values being matrices of shape
121+
`(num_chains, num_samples)`.
122122
123123
`keytype` specifies what the type of the keys used in the returned `OrderedDict` are.
124-
Currently, only `String` and `VarName` are supported. `whichlogprob` specifies
125-
which log-probabilities to compute. It can be `:both`, `:prior`, or
124+
Currently, only `String` and `VarName` are supported, with `VarName` being the default.
125+
`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or
126126
`:likelihood`.
127127
128128
See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref).
@@ -177,13 +177,13 @@ julia> # A chain with 3 iterations.
177177
);
178178
179179
julia> pointwise_logdensities(model, chain)
180-
OrderedDict{String, Matrix{Float64}} with 6 entries:
181-
"s" => [-0.802775; -1.38222; -2.09861;;]
182-
"m" => [-8.91894; -7.51551; -7.46824;;]
183-
"xs[1]" => [-5.41894; -5.26551; -5.63491;;]
184-
"xs[2]" => [-2.91894; -3.51551; -4.13491;;]
185-
"xs[3]" => [-1.41894; -2.26551; -2.96824;;]
186-
"y" => [-0.918939; -1.51551; -2.13491;;]
180+
OrderedDict{VarName, Matrix{Float64}} with 6 entries:
181+
s => [-0.802775; -1.38222; -2.09861;;]
182+
m => [-8.91894; -7.51551; -7.46824;;]
183+
xs[1] => [-5.41894; -5.26551; -5.63491;;]
184+
xs[2] => [-2.91894; -3.51551; -4.13491;;]
185+
xs[3] => [-1.41894; -2.26551; -2.96824;;]
186+
y => [-0.918939; -1.51551; -2.13491;;]
187187
188188
julia> pointwise_logdensities(model, chain, String)
189189
OrderedDict{String, Matrix{Float64}} with 6 entries:
@@ -225,7 +225,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])],
225225
```
226226
"""
227227
function pointwise_logdensities(
228-
model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both)
228+
model::Model, chain, ::Type{KeyType}=VarName, ::Val{whichlogprob}=Val(:both)
229229
) where {KeyType,whichlogprob}
230230
# Get the data by executing the model once
231231
vi = VarInfo(model)
@@ -283,7 +283,7 @@ including the likelihood terms.
283283
284284
See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref).
285285
"""
286-
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T}
286+
function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=VarName) where {T}
287287
return pointwise_logdensities(model, chain, T, Val(:likelihood))
288288
end
289289

@@ -301,7 +301,7 @@ including the prior terms.
301301
See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref).
302302
"""
303303
function pointwise_prior_logdensities(
304-
model::Model, chain, keytype::Type{T}=String
304+
model::Model, chain, keytype::Type{T}=VarName
305305
) where {T}
306306
return pointwise_logdensities(model, chain, T, Val(:prior))
307307
end

test/pointwise_logdensities.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ end
6060
loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain)
6161

6262
# Check that they contain the correct variables.
63-
@test all(string(vn) in keys(logjoints_pointwise) for vn in vns)
64-
@test all(string(vn) in keys(logpriors_pointwise) for vn in vns)
65-
@test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise))
66-
@test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns)
67-
@test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise))
63+
@test all(vn in keys(logjoints_pointwise) for vn in vns)
64+
@test all(vn in keys(logpriors_pointwise) for vn in vns)
65+
@test !any(Base.Fix1(subsumes, @varname(x)), keys(logpriors_pointwise))
66+
@test !any(vn in keys(loglikelihoods_pointwise) for vn in vns)
67+
@test all(Base.Fix1(subsumes, @varname(x)), keys(loglikelihoods_pointwise))
6868

6969
# Get the sum of the logjoints for each of the iterations.
7070
logjoints = [

0 commit comments

Comments
 (0)