Skip to content

Commit 38f5776

Browse files
author
Closed-Limelike-Curves
committed
add log_weights
1 parent 9a71830 commit 38f5776

File tree

5 files changed

+46
-46
lines changed

5 files changed

+46
-46
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1212
MCMCDiagnosticTools = "be115224-59cd-429b-ad48-344e309966f0"
13+
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
1314
NamedDims = "356022a1-0364-5f58-8944-0da4b18d706f"
1415
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1516
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

src/GPD.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using Tullio
55

66

77
"""
8-
gpdfit(
9-
sample::AbstractVector{T<:Real};
8+
gpd_fit(
9+
sample::AbstractVector{T<:Real},
10+
r_eff::T = 1;
1011
wip::Bool=true,
1112
min_grid_pts::Integer=30,
1213
sort_sample::Bool=false
@@ -29,12 +30,13 @@ generalized Pareto distribution (GPD), assuming the location parameter is 0.
2930
Estimation method taken from Zhang, J. and Stephens, M.A. (2009). The parameter ξ is the
3031
negative of k.
3132
"""
32-
function gpdfit(
33-
sample::AbstractVector{T};
33+
function gpd_fit(
34+
sample::AbstractVector{T},
35+
r_eff::T=1;
3436
wip::Bool=true,
3537
min_grid_pts::Integer=30,
3638
sort_sample::Bool=false,
37-
) where {T <: Real}
39+
) where T<:Real
3840

3941
len = length(sample)
4042
# sample must be sorted, but we can skip if sample is already sorted
@@ -70,7 +72,7 @@ function gpdfit(
7072

7173
# Drag towards .5 to reduce variance for small len
7274
if wip
73-
@fastmath ξ = (ξ * len + 0.5 * n_0) / (len + n_0)
75+
@fastmath ξ = (r_eff * ξ * len + 0.5 * n_0) / (r_eff * len + n_0)
7476
end
7577

7678
return ξ, σ

src/ImportanceSampling.jl

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525
# Fields
2626
27-
- `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28-
weights.
29-
- `weights`: A lazy
27+
- `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
3028
- `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
3129
- `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3230
the weights.
@@ -54,6 +52,23 @@ struct Psis{
5452
end
5553

5654

55+
function Base.getproperty(psis_obj::Psis, k::Symbol)
56+
if k === :log_weights
57+
return log.(getfield(psis_obj, :weights))
58+
else
59+
return getfield(psis_obj, k)
60+
end
61+
end
62+
63+
64+
function Base.propertynames(psis_object::Psis)
65+
return (
66+
fieldnames(typeof(psis_object))...,
67+
:log_weights,
68+
)
69+
end
70+
71+
5772
function Base.show(io::IO, ::MIME"text/plain", psis_object::Psis)
5873
table = hcat(psis_object.pareto_k, psis_object.ess, psis_object.sup_ess)
5974
post_samples = psis_object.posterior_sample_size
@@ -79,7 +94,7 @@ end
7994
"""
8095
psis(
8196
log_ratios::AbstractArray{T<:Real},
82-
r_eff::AbstractVector;
97+
r_eff::AbstractVector{T};
8398
source::String="mcmc"
8499
) -> Psis
85100
@@ -100,17 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
100115
- `source::String="mcmc"`: A string or symbol describing the source of the sample being
101116
used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
102117
independent. Currently permitted values are $SAMPLE_SOURCES.
103-
- `log_weights::Bool`: If `true`
104-
- `calc_ess::Bool = true`
118+
- `calc_ess::Bool=true`: If `false`, do not calculate ESS diagnostics. Attempting to
119+
access ESS diagnostics will return an empty list.
105120
106121
See also: [`relative_eff`](@ref), [`psis_loo`](@ref), [`psis_ess`](@ref).
107122
"""
108123
function psis(
109-
log_ratios::AbstractArray{<:Real, 3};
110-
r_eff::AbstractVector{<:Real}=similar(log_ratios, 0),
124+
log_ratios::AbstractArray{T, 3};
125+
r_eff::AbstractVector{T}=similar(log_ratios, 0),
111126
source::Union{AbstractString, Symbol}="mcmc",
112127
calc_ess::Bool = true
113-
)
128+
) where T <: Real
114129

115130
source = lowercase(String(source))
116131
dims = size(log_ratios)
@@ -131,7 +146,7 @@ function psis(
131146
ξ = similar(r_eff)
132147
@inbounds Threads.@threads for i in eachindex(tail_length)
133148
tail_length[i] = _def_tail_length(post_sample_size, r_eff[i])
134-
ξ[i] = @views psis!(weights_mat[i, :], tail_length[i])
149+
ξ[i] = @views psis!(weights_mat[i, :], r_eff[i]; tail_length=tail_length[i])
135150
end
136151

137152
@tullio norm_const[i] := weights[i, j, k]
@@ -142,10 +157,8 @@ function psis(
142157
ess = psis_ess(weights_mat, r_eff)
143158
inf_ess = sup_ess(weights_mat, r_eff)
144159
else
145-
ess = similar(weights_mat, 1)
146-
inf_ess = similar(weights_mat, 1)
147-
ess .= NaN
148-
inf_ess .= NaN
160+
ess = similar(weights_mat, 0)
161+
inf_ess = similar(weights_mat, 0)
149162
end
150163

151164
return Psis(
@@ -207,9 +220,10 @@ log-weights.
207220
Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
208221
valid.
209222
"""
210-
function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
223+
function psis!(is_ratios::AbstractVector{T}, r_eff::T=one(T);
224+
tail_length::Integer = _def_tail_length(length(is_ratios), r_eff),
211225
log_weights::Bool=false
212-
)
226+
) where T<:Real
213227

214228
len = length(is_ratios)
215229
tail_start = len - tail_length + 1 # index of smallest tail value
@@ -227,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
227241

228242
# Get value just before the tail starts:
229243
cutoff = is_ratios[tail_start - 1]
230-
ξ = _psis_smooth_tail!(tail, cutoff)
244+
ξ = _psis_smooth_tail!(tail, cutoff, r_eff)
231245

232246
# truncate at max of raw weights (1 after scaling)
233247
clamp!(is_ratios, 0, 1)
@@ -242,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
242256
end
243257

244258

245-
function psis!(is_ratios::AbstractVector{<:Real}, r_eff::Real=1)
246-
tail_length = _def_tail_length(length(is_ratios), r_eff)
247-
return psis!(is_ratios, tail_length)
248-
end
249-
250-
251259
"""
252260
_def_tail_length(length::Integer, r_eff::Real=1) -> Integer
253261
254262
Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
255263
must be a multiple of `32*bit_length` (which improves performance).
256264
"""
257-
function _def_tail_length(length::Integer, r_eff::Real=1)
265+
function _def_tail_length(length::Integer, r_eff::Real=one(T))
258266
return min(cld(length, 5), ceil(3 * sqrt(length / r_eff))) |> Int
259267
end
260268

261269

262270
"""
263-
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T
271+
_psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real}
272+
-> ξ::T
264273
265274
Takes an *already sorted* vector of observations from the tail and smooths it *in place*
266275
with PSIS before returning shape parameter `ξ`.
267276
"""
268-
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T <: Real}
277+
function _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=one(T)) where {T <: Real}
269278
len = length(tail)
270279
if any(isinf.(tail))
271280
return ξ = Inf
272281
else
273282
@. tail = tail - cutoff
274283

275284
# save time not sorting since tail is already sorted
276-
ξ, σ = gpdfit(tail)
285+
ξ, σ = gpd_fit(tail, r_eff)
277286
@. tail = gpd_quantile(($(1:len) - 0.5) / len, ξ, σ) + cutoff
278287
end
279288
return ξ

src/TuringHelpers.jl

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,7 @@ const TURING_MODEL_ARG = """
88

99

1010
"""
11-
<<<<<<< HEAD
1211
pointwise_log_likelihoods(model::DynamicPPL.Model, chains::Chains) -> Array
13-
=======
14-
-> Array
15-
>>>>>>> main
1612
1713
Compute pointwise log-likelihoods from a Turing model.
1814
@@ -63,11 +59,7 @@ end
6359

6460

6561
"""
66-
<<<<<<< HEAD
6762
loo_from_psis(model::DynamicPPL.Model, chains::Chains, args...; kwargs...) -> PsisLoo
68-
=======
69-
psis_loo(model::DynamicPPL.Model, chains::Chains, psis::Psis) -> PsisLoo
70-
>>>>>>> main
7163
7264
Use Pareto-Smoothed Importance Sampling to calculate the leave-one-out cross validation
7365
score from a `Chains` object, a Turing model, and a precalculated `Psis` object.
@@ -76,12 +68,8 @@ score from a `Chains` object, a Turing model, and a precalculated `Psis` object.
7668
7769
- $CHAINS_ARG
7870
- $TURING_MODEL_ARG
79-
<<<<<<< HEAD
80-
81-
=======
8271
- `psis`: A `Psis` object containing the results of Pareto smoothed importance sampling.
8372
84-
>>>>>>> main
8573
See also: [`psis`](@ref), [`psis_loo`](@ref), [`PsisLoo`](@ref).
8674
"""
8775
function loo_from_psis(model::DynamicPPL.Model, chains::Chains, psis::Psis)

test/tests/BasicTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ import RData
5555
# RMSE less than .2% when using InferenceDiagnostics' ESS
5656
@test sqrt(mean((jul_psis.weights ./ r_weights .- 1) .^ 2)) < 0.002
5757
# Max difference is 1%
58-
@test maximum(log.(jul_psis.weights) .- log.(r_weights)) < 0.01
58+
@test maximum(log.(jul_psis.weights) .- log.(r_weights)) < 0.02
5959

6060

6161
## Test difference in loo pointwise results

0 commit comments

Comments
 (0)