From 7482a38ddf1f2811e405e07757105872ad5a89e1 Mon Sep 17 00:00:00 2001 From: rofinn Date: Tue, 28 May 2019 16:02:05 -0500 Subject: [PATCH] Clean eweights code to use the default `Weights` type and support alternate methods. --- Project.toml | 3 +- docs/src/weights.md | 17 +++++----- src/StatsBase.jl | 3 +- src/weights.jl | 80 +++++++++++++++++++++++++++++++-------------- test/runtests.jl | 1 + test/weights.jl | 38 ++++++++++++++++++--- 6 files changed, 101 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 69967992b..215bf0b43 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [extras] +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["DelimitedFiles", "Test"] +test = ["Dates", "DelimitedFiles", "Test"] diff --git a/docs/src/weights.md b/docs/src/weights.md index 424d207f9..94f02ebad 100644 --- a/docs/src/weights.md +++ b/docs/src/weights.md @@ -41,23 +41,22 @@ w = ProbabilityWeights([0.2, 0.1, 0.3]) w = pweights([0.2, 0.1, 0.3]) ``` -### `ExponentialWeights` +### `Weights` -Exponential weights are a common form of temporal weights which assign exponentially decreasing -weight to past observations. +The `Weights` type describes a generic weights vector which does not support all operations possible for `FrequencyWeights`, `AnalyticWeights` and `ProbabilityWeights`. ```julia -w = ExponentialWeights([0.1837, 0.2222, 0.2688, 0.3253]) -w = eweights(4, 0.173) # construction based on length and rate parameter +w = Weights([1., 2., 3.]) +w = weights([1., 2., 3.]) ``` -### `Weights` +### `eweights` -The `Weights` type describes a generic weights vector which does not support all operations possible for `FrequencyWeights`, `AnalyticWeights` and `ProbabilityWeights`. +Exponential weights are a common form of temporal weights which assign exponentially decreasing +weight to past observations. ```julia -w = Weights([1., 2., 3.]) -w = weights([1., 2., 3.]) +w = eweights(4, 0.173) # construction based on length and rate parameter ``` ## Methods diff --git a/src/StatsBase.jl b/src/StatsBase.jl index 46cd90aad..102af4ab9 100644 --- a/src/StatsBase.jl +++ b/src/StatsBase.jl @@ -30,12 +30,11 @@ export AnalyticWeights, # to represent an analytic/precision/reliability weight vector FrequencyWeights, # to representing a frequency/case/repeat weight vector ProbabilityWeights, # to representing a probability/sampling weight vector - ExponentialWeights, # to represent an exponential weight vector weights, # construct a generic Weights vector aweights, # construct an AnalyticWeights vector fweights, # construct a FrequencyWeights vector pweights, # construct a ProbabilityWeights vector - eweights, # construct an ExponentialWeights vector + eweights, # construct an exponential Weights vector wsum, # weighted sum with vector as second argument wsum!, # weighted sum across dimensions with provided storage wmean, # weighted mean diff --git a/src/weights.jl b/src/weights.jl index 0e2f6af2e..06b803651 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -193,34 +193,42 @@ pweights(vs::RealArray) = ProbabilityWeights(vec(vs)) end end -@weights ExponentialWeights - -@doc """ - ExponentialWeights(vs, wsum=sum(vs)) - -Construct an `ExponentialWeights` vector with weight values `vs`. -A precomputed sum may be provided as `wsum`. +""" + eweights(t::AbstractVector{<:Integer}, λ::Real) + eweights(t::AbstractVector{T}, r::StepRange{T}, λ::Real) where T + eweights(n::Integer, λ::Real) -Exponential weights are a common form of temporal weights which assign exponentially -decreasing weight to past observations, which in this case corresponds to the front of -the vector. That is, newer observations are assumed to be at the end. -""" ExponentialWeights +Construct [`Weights`](@ref) vector which assigns exponentially decreasing weights to past +observations, which in this case corresponds to larger integer values `i` in `t`. -""" - eweights(n, λ) +For each element `i` in `t` the weight value is computed as: -Construct an [`ExponentialWeights`](@ref) vector with length `n`, -where each element in position ``i`` is set to ``λ (1 - λ)^{1 - i}``. +``λ (1 - λ)^{1 - i}`` ``λ`` is a smoothing factor or rate parameter such that ``0 < λ \\leq 1``. As this value approaches 0, the resulting weights will be almost equal, while values closer to 1 will put greater weight on the tail elements of the vector. # Examples +```julia-repl +julia> eweights(1:10, 0.3) +10-element Weights{Float64,Float64,Array{Float64,1}}: + 0.3 + 0.42857142857142855 + 0.6122448979591837 + 0.8746355685131197 + 1.249479383590171 + 1.7849705479859588 + 2.549957925694227 + 3.642797036706039 + 5.203995766722913 + 7.434279666747019 +``` +Simply passing the number of observations `n` is equivalent to passing in `1:n`. ```julia-repl julia> eweights(10, 0.3) -10-element ExponentialWeights{Float64,Float64,Array{Float64,1}}: +10-element Weights{Float64,Float64,Array{Float64,1}}: 0.3 0.42857142857142855 0.6122448979591837 @@ -232,20 +240,42 @@ julia> eweights(10, 0.3) 5.203995766722913 7.434279666747019 ``` + +Finally, passing arbitrary times and a step range is equivalent to passing +`something.(indexin(t, r))`. +```julia-repl +julia> eweights([1, 3, 5], 1:10, 0.3) +3-element Weights{Float64,Float64,Array{Float64,1}}: + 0.3 + 0.6122448979591837 + 1.249479383590171 +``` """ +function eweights(t::AbstractVector{T}, λ::Real) where T<:Integer + 0 < λ <= 1 || throw(ArgumentError("Smoothing factor must be between 0 and 1")) + + w0 = map(t) do i + i > 0 || throw(ArgumentError("Time indices must be non-zero positive integers")) + λ * (1 - λ)^(1 - i) + end + + s = sum(w0) + Weights{typeof(s), eltype(w0), typeof(w0)}(w0, s) +end + function eweights(n::Integer, λ::Real) n > 0 || throw(ArgumentError("cannot construct exponential weights of length < 1")) - 0 < λ <= 1 || throw(ArgumentError("smoothing factor must be between 0 and 1")) - w0 = map(i -> λ * (1 - λ)^(1 - i), 1:n) - s = sum(w0) - ExponentialWeights{typeof(s), eltype(w0), typeof(w0)}(w0, s) + eweights(1:n, λ) end +eweights(t::AbstractVector, r::AbstractRange, λ::Real) = + eweights(something.(indexin(t, r)), λ) + # NOTE: No variance correction is implemented for exponential weights ##### Equality tests ##### -for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, ExponentialWeights, Weights) +for w in (AnalyticWeights, FrequencyWeights, ProbabilityWeights, Weights) @eval begin Base.isequal(x::$w, y::$w) = isequal(x.sum, y.sum) && isequal(x.values, y.values) Base.:(==)(x::$w, y::$w) = (x.sum == y.sum) && (x.values == y.values) @@ -531,7 +561,7 @@ _mean(A::AbstractArray{T}, w::AbstractWeights{W}, dims::Int) where {T,W} = Compute the weighted quantiles of a vector `v` at a specified set of probability values `p`, using weights given by a weight vector `w` (of type `AbstractWeights`). Weights must not be negative. The weights and data vectors must have the same length. -`NaN` is returned if `x` contains any `NaN` values. An error is raised if `w` contains +`NaN` is returned if `x` contains any `NaN` values. An error is raised if `w` contains any `NaN` values. With [`FrequencyWeights`](@ref), the function returns the same result as @@ -552,15 +582,15 @@ function quantile(v::RealVector{V}, w::AbstractWeights{W}, p::RealVector) where all(x -> 0 <= x <= 1, p) || throw(ArgumentError("input probability out of [0,1] range")) w.sum == 0 && throw(ArgumentError("weight vector cannot sum to zero")) - length(v) == length(w) || throw(ArgumentError("data and weight vectors must be the same size," * + length(v) == length(w) || throw(ArgumentError("data and weight vectors must be the same size," * "got $(length(v)) and $(length(w))")) for x in w.values isnan(x) && throw(ArgumentError("weight vector cannot contain NaN entries")) x < 0 && throw(ArgumentError("weight vector cannot contain negative entries")) end - isa(w, FrequencyWeights) && !(eltype(w) <: Integer) && any(!isinteger, w) && - throw(ArgumentError("The values of the vector of `FrequencyWeights` must be numerically" * + isa(w, FrequencyWeights) && !(eltype(w) <: Integer) && any(!isinteger, w) && + throw(ArgumentError("The values of the vector of `FrequencyWeights` must be numerically" * "equal to integers. Use `ProbabilityWeights` or `AnalyticWeights` instead.")) # remove zeros weights and sort diff --git a/test/runtests.jl b/test/runtests.jl index dac21a0c8..500539c74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,5 @@ using StatsBase +using Dates using LinearAlgebra using Random using Statistics diff --git a/test/weights.jl b/test/weights.jl index ecd07f90d..d8412466b 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -450,20 +450,50 @@ end end @testset "ExponentialWeights" begin - @testset "Basic Usage" begin + @testset "Usage" begin θ = 5.25 λ = 1 - exp(-1 / θ) # simple conversion for the more common/readable method - v = [λ*(1-λ)^(1-i) for i = 1:4] - w = ExponentialWeights(v) + w = Weights(v) @test round.(w, digits=4) == [0.1734, 0.2098, 0.2539, 0.3071] - @test eweights(4, λ) ≈ w + + @testset "basic" begin + @test eweights(1:4, λ) ≈ w + end + + @testset "1:n" begin + @test eweights(4, λ) ≈ w + end + + @testset "indexin" begin + v = [λ*(1-λ)^(1-i) for i = 1:10] + + # Test that we should be able to skip indices easily + @test eweights([1, 3, 5, 7], 1:10, λ) ≈ Weights(v[[1, 3, 5, 7]]) + + # This should also work with actual time types + t1 = DateTime(2019, 1, 1, 1) + tx = t1 + Hour(7) + tn = DateTime(2019, 1, 2, 1) + + @test eweights(t1:Hour(2):tx, t1:Hour(1):tn, λ) ≈ Weights(v[[1, 3, 5, 7]]) + end end @testset "Failure Conditions" begin + # n == 0 @test_throws ArgumentError eweights(0, 0.3) + + # λ > 1.0 @test_throws ArgumentError eweights(1, 1.1) + + # time indices are not all positive non-zero integers + @test_throws ArgumentError eweights([0, 1, 2, 3], 0.3) + + # Passing in an array of bools will work because Bool <: Integer, + # but any `false` values will trigger the same argument error as 0.0 + @test_throws ArgumentError eweights([true, false, true, true], 0.3) end end