@@ -24,9 +24,7 @@ A struct containing the results of Pareto-smoothed importance sampling.
2424
2525# Fields
2626
27- - `log_weights`: A vector of smoothed and truncated but *unnormalized* importance sampling
28- weights.
29- - `weights`: A lazy
27+ - `weights`: A vector of smoothed, truncated, and normalized importance sampling weights.
3028 - `pareto_k`: Estimates of the shape parameter `k` of the generalized Pareto distribution.
3129 - `ess`: Estimated effective sample size for each LOO evaluation, based on the variance of
3230 the weights.
@@ -54,6 +52,23 @@ struct Psis{
5452end
5553
5654
55+ function Base. getproperty (psis_obj:: Psis , k:: Symbol )
56+ if k === :log_weights
57+ return log .(getfield (psis_obj, :weights ))
58+ else
59+ return getfield (psis_obj, k)
60+ end
61+ end
62+
63+
64+ function Base. propertynames (psis_object:: Psis )
65+ return (
66+ fieldnames (typeof (psis_object))... ,
67+ :log_weights ,
68+ )
69+ end
70+
71+
5772function Base. show (io:: IO , :: MIME"text/plain" , psis_object:: Psis )
5873 table = hcat (psis_object. pareto_k, psis_object. ess, psis_object. sup_ess)
5974 post_samples = psis_object. posterior_sample_size
7994"""
8095 psis(
8196 log_ratios::AbstractArray{T<:Real},
82- r_eff::AbstractVector;
97+ r_eff::AbstractVector{T} ;
8398 source::String="mcmc"
8499 ) -> Psis
85100
@@ -100,17 +115,17 @@ Implements Pareto-smoothed importance sampling (PSIS).
100115 - `source::String="mcmc"`: A string or symbol describing the source of the sample being
101116 used. If `"mcmc"`, adjusts ESS for autocorrelation. Otherwise, samples are assumed to be
102117 independent. Currently permitted values are $SAMPLE_SOURCES .
103- - `log_weights ::Bool`: If `true`
104- - `calc_ess::Bool = true`
118+ - `calc_ess ::Bool=true `: If `false`, do not calculate ESS diagnostics. Attempting to
119+ access ESS diagnostics will return an empty list.
105120
106121See also: [`relative_eff`]@ref, [`psis_loo`]@ref, [`psis_ess`]@ref.
107122"""
108123function psis (
109- log_ratios:: AbstractArray{<:Real , 3} ;
110- r_eff:: AbstractVector{<:Real } = similar (log_ratios, 0 ),
124+ log_ratios:: AbstractArray{T , 3} ;
125+ r_eff:: AbstractVector{T } = similar (log_ratios, 0 ),
111126 source:: Union{AbstractString, Symbol} = " mcmc" ,
112127 calc_ess:: Bool = true
113- )
128+ ) where T <: Real
114129
115130 source = lowercase (String (source))
116131 dims = size (log_ratios)
@@ -131,7 +146,7 @@ function psis(
131146 ξ = similar (r_eff)
132147 @inbounds Threads. @threads for i in eachindex (tail_length)
133148 tail_length[i] = _def_tail_length (post_sample_size, r_eff[i])
134- ξ[i] = @views psis! (weights_mat[i, :], tail_length[i])
149+ ξ[i] = @views psis! (weights_mat[i, :], r_eff[i]; tail_length = tail_length[i])
135150 end
136151
137152 @tullio norm_const[i] := weights[i, j, k]
@@ -142,10 +157,8 @@ function psis(
142157 ess = psis_ess (weights_mat, r_eff)
143158 inf_ess = sup_ess (weights_mat, r_eff)
144159 else
145- ess = similar (weights_mat, 1 )
146- inf_ess = similar (weights_mat, 1 )
147- ess .= NaN
148- inf_ess .= NaN
160+ ess = similar (weights_mat, 0 )
161+ inf_ess = similar (weights_mat, 0 )
149162 end
150163
151164 return Psis (
@@ -207,9 +220,10 @@ log-weights.
207220Unlike the methods for arrays, `psis!` performs no checks to make sure the input values are
208221valid.
209222"""
210- function psis! (is_ratios:: AbstractVector{<:Real} , tail_length:: Integer ;
223+ function psis! (is_ratios:: AbstractVector{T} , r_eff:: T = one (T);
224+ tail_length:: Integer = _def_tail_length (length (is_ratios), r_eff),
211225 log_weights:: Bool = false
212- )
226+ ) where T <: Real
213227
214228 len = length (is_ratios)
215229 tail_start = len - tail_length + 1 # index of smallest tail value
@@ -227,7 +241,7 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
227241
228242 # Get value just before the tail starts:
229243 cutoff = is_ratios[tail_start - 1 ]
230- ξ = _psis_smooth_tail! (tail, cutoff)
244+ ξ = _psis_smooth_tail! (tail, cutoff, r_eff )
231245
232246 # truncate at max of raw weights (1 after scaling)
233247 clamp! (is_ratios, 0 , 1 )
@@ -242,38 +256,33 @@ function psis!(is_ratios::AbstractVector{<:Real}, tail_length::Integer;
242256end
243257
244258
245- function psis! (is_ratios:: AbstractVector{<:Real} , r_eff:: Real = 1 )
246- tail_length = _def_tail_length (length (is_ratios), r_eff)
247- return psis! (is_ratios, tail_length)
248- end
249-
250-
251259"""
252260 _def_tail_length(log_ratios::AbstractVector, r_eff::Real) -> Integer
253261
254262Define the tail length as in Vehtari et al. (2019), with the small addition that the tail
255263must a multiple of `32*bit_length` (which improves performance).
256264"""
257- function _def_tail_length (length:: Integer , r_eff:: Real = 1 )
265+ function _def_tail_length (length:: Integer , r_eff:: Real = one (T) )
258266 return min (cld (length, 5 ), ceil (3 * sqrt (length / r_eff))) |> Int
259267end
260268
261269
262270"""
263- _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T) where {T<:Real} -> ξ::T
271+ _psis_smooth_tail!(tail::AbstractVector{T}, cutoff::T, r_eff::T=1) where {T<:Real}
272+ -> ξ::T
264273
265274Takes an *already sorted* vector of observations from the tail and smooths it *in place*
266275with PSIS before returning shape parameter `ξ`.
267276"""
268- function _psis_smooth_tail! (tail:: AbstractVector{T} , cutoff:: T ) where {T <: Real }
277+ function _psis_smooth_tail! (tail:: AbstractVector{T} , cutoff:: T , r_eff :: T = one (T) ) where {T <: Real }
269278 len = length (tail)
270279 if any (isinf .(tail))
271280 return ξ = Inf
272281 else
273282 @. tail = tail - cutoff
274283
275284 # save time not sorting since tail is already sorted
276- ξ, σ = gpdfit (tail)
285+ ξ, σ = gpd_fit (tail, r_eff )
277286 @. tail = gpd_quantile (($ (1 : len) - 0.5 ) / len, ξ, σ) + cutoff
278287 end
279288 return ξ
0 commit comments