Skip to content

Commit 1d24d84

Browse files
committed
Support weighted quantiles in cut
This requires adding an extension point for StatsBase. Unfortunately, more copies of the data and weights are made than necessary, as StatsBase supports neither an in-place weighted quantile! on pre-sorted data nor taking a view of a weights vector (JuliaStats/StatsBase.jl#723).
1 parent 311e593 commit 1d24d84

File tree

5 files changed

+79
-9
lines changed

5 files changed

+79
-9
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1616
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
1717
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1818
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
19+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1920
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
2021
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
2122

2223
[extensions]
2324
CategoricalArraysArrowExt = "Arrow"
2425
CategoricalArraysJSONExt = "JSON"
2526
CategoricalArraysRecipesBaseExt = "RecipesBase"
27+
CategoricalArraysStatsBaseExt = "StatsBase"
2628
CategoricalArraysSentinelArraysExt = "SentinelArrays"
2729
CategoricalArraysStructTypesExt = "StructTypes"
2830

@@ -37,6 +39,7 @@ RecipesBase = "1.1"
3739
Requires = "1"
3840
SentinelArrays = "1"
3941
Statistics = "1"
42+
StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
4043
StructTypes = "1"
4144
julia = "1.6"
4245

@@ -49,8 +52,9 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
4952
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
5053
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
5154
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
55+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
5256
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
5357
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5458

5559
[targets]
56-
test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StructTypes", "Test"]
60+
test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StatsBase", "StructTypes", "Test"]

ext/CategoricalArraysStatsBaseExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module CategoricalArraysStatsBaseExt
2+
3+
if isdefined(Base, :get_extension)
4+
import CategoricalArrays: _wquantile
5+
using StatsBase
6+
else
7+
import ..CategoricalArrays: _wquantile
8+
using ..StatsBase
9+
end
10+
11+
_wquantile(x::AbstractArray, w::AbstractWeights, p::AbstractVector) = quantile(x, w, p)
12+
13+
end

src/CategoricalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ module CategoricalArrays
4747
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
4848
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
4949
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")
50+
@require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl")
5051
@require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl")
5152
end
5253
end

src/extras.jl

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,11 +333,17 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
333333
return breaks
334334
end
335335

336+
# AbstractWeights method is defined in StatsBase extension
337+
# There is no in-place weighted quantile method in StatsBase
338+
_wquantile(x::AbstractArray, w::AbstractVector, p::AbstractVector) =
339+
throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl"))
340+
336341
"""
337342
cut(x::AbstractArray, ngroups::Integer;
338343
labels::Union{AbstractVector{<:AbstractString},Function},
339344
sigdigits::Integer=3,
340-
allowempty::Bool=false)
345+
allowempty::Bool=false,
346+
weights::Union{AbstractWeights, Nothing}=nothing)
341347
342348
Cut a numeric array into `ngroups` quantiles.
343349
@@ -369,19 +375,41 @@ quantiles.
369375
other than the last one are equal, generating empty intervals;
370376
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
371377
unused levels (but duplicate labels are not allowed).
378+
* `weights::Union{AbstractWeights, Nothing}=nothing`: observation weights to use when
379+
computing quantiles (see `quantile` documentation in StatsBase).
372380
"""
373381
function cut(x::AbstractArray, ngroups::Integer;
374382
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
375383
sigdigits::Integer=3,
376-
allowempty::Bool=false)
384+
allowempty::Bool=false,
385+
weights::Union{AbstractVector, Nothing}=nothing)
377386
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
378-
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
379-
min_x, max_x = first(sorted_x), last(sorted_x)
380-
if (min_x isa Number && isnan(min_x)) ||
381-
(max_x isa Number && isnan(max_x))
382-
throw(ArgumentError("NaN values are not allowed in input vector"))
387+
if weights === nothing
388+
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
389+
min_x, max_x = first(sorted_x), last(sorted_x)
390+
if (min_x isa Number && isnan(min_x)) ||
391+
(max_x isa Number && isnan(max_x))
392+
throw(ArgumentError("NaN values are not allowed in input vector"))
393+
end
394+
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
395+
else
396+
if eltype(x) >: Missing
397+
nm_inds = findall(!ismissing, x)
398+
nm_x = view(x, nm_inds)
399+
# TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723)
400+
nm_weights = weights[nm_inds]
401+
else
402+
nm_x = x
403+
nm_weights = weights
404+
end
405+
sorted_x = sort(nm_x)
406+
min_x, max_x = first(sorted_x), last(sorted_x)
407+
if (min_x isa Number && isnan(min_x)) ||
408+
(max_x isa Number && isnan(max_x))
409+
throw(ArgumentError("NaN values are not allowed in input vector"))
410+
end
411+
qs = _wquantile(nm_x, nm_weights, (1:(ngroups-1))/ngroups)
383412
end
384-
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
385413
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
386414
if !allowempty && !allunique(@view breaks[1:end-1])
387415
throw(ArgumentError("cannot compute $ngroups quantiles due to " *

test/15_extras.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module TestExtras
22
using Test
33
using CategoricalArrays
4+
using StatsBase
5+
using Missings
46

57
const ≅ = isequal
68

@@ -423,4 +425,26 @@ end
423425

424426
end
425427

428+
@testset "cut with weighted quantiles" begin
429+
@test_throws ArgumentError cut(1:3, 3, weights=1:3)
430+
431+
x = collect(Float64, 1:100)
432+
w = fweights(repeat(1:10, inner=10))
433+
y = cut(x, 10, weights=w)
434+
@test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10)))
435+
@test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
436+
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]
437+
438+
mx = allowmissing(x)
439+
mx[2] = mx[10] = missing
440+
nm_inds = .!ismissing.(mx)
441+
y = cut(mx, 10, weights=w)
442+
@test levelcode.(y) ≅ levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10)))
443+
@test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
444+
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]
445+
446+
x[5] = NaN
447+
@test_throws ArgumentError cut(x, 3, weights=w)
448+
end
449+
426450
end

0 commit comments

Comments
 (0)