Skip to content

Commit 40d8c20

Browse files
committed
Support weighted quantiles in cut
This requires adding an extension point for StatsBase. Unfortunately, more copies of the data and weights are made than necessary, as StatsBase supports neither an in-place weighted `quantile!` on pre-sorted data nor taking a view of a weights vector (JuliaStats/StatsBase.jl#723).
1 parent 311e593 commit 40d8c20

File tree

5 files changed

+73
-5
lines changed

5 files changed

+73
-5
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
1616
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
1717
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
1818
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
19+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1920
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
2021
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
2122

2223
[extensions]
2324
CategoricalArraysArrowExt = "Arrow"
2425
CategoricalArraysJSONExt = "JSON"
2526
CategoricalArraysRecipesBaseExt = "RecipesBase"
27+
CategoricalArraysStatsBaseExt = "StatsBase"
2628
CategoricalArraysSentinelArraysExt = "SentinelArrays"
2729
CategoricalArraysStructTypesExt = "StructTypes"
2830

@@ -37,6 +39,7 @@ RecipesBase = "1.1"
3739
Requires = "1"
3840
SentinelArrays = "1"
3941
Statistics = "1"
42+
StatsBase = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
4043
StructTypes = "1"
4144
julia = "1.6"
4245

@@ -49,8 +52,9 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
4952
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
5053
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
5154
SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
55+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
5256
StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4"
5357
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5458

5559
[targets]
56-
test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StructTypes", "Test"]
60+
test = ["Arrow", "Dates", "JSON", "JSON3", "Plots", "PooledArrays", "RecipesBase", "SentinelArrays", "StatsBase", "StructTypes", "Test"]

ext/CategoricalArraysStatsBaseExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module CategoricalArraysStatsBaseExt
2+
3+
if isdefined(Base, :get_extension)
4+
import CategoricalArrays: _quantile!
5+
using StatsBase
6+
else
7+
import ..CategoricalArrays: _quantile!
8+
using ..StatsBase
9+
end
10+
11+
_quantile!(x::AbstractArray, w::AbstractWeights, p::AbstractVector) = quantile(x, w, p)
12+
13+
end

src/CategoricalArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ module CategoricalArrays
4747
@require JSON="682c06a0-de6a-54ab-a142-c8b1cf79cde6" include("../ext/CategoricalArraysJSONExt.jl")
4848
@require RecipesBase="3cdcf5f2-1ef4-517c-9805-6587b60abb01" include("../ext/CategoricalArraysRecipesBaseExt.jl")
4949
@require SentinelArrays="91c51154-3ec4-41a3-a24f-3f23e20d615c" include("../ext/CategoricalArraysSentinelArraysExt.jl")
50+
@require StatsBase="2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" include("../ext/CategoricalArraysStatsBaseExt.jl")
5051
@require StructTypes="856f2bd8-1eba-4b0a-8007-ebc267875bd4" include("../ext/CategoricalArraysStructTypesExt.jl")
5152
end
5253
end

src/extras.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,11 +333,19 @@ function find_breaks(v::AbstractVector, qs::AbstractVector)
333333
return breaks
334334
end
335335

336+
_quantile!(x::AbstractArray, w::Nothing, p::AbstractVector) =
337+
quantile!(x, p, sorted=true)
338+
# AbstractWeights method is defined in StatsBase extension
339+
# There is no in-place weighted quantile method in StatsBase
340+
_quantile!(x::AbstractArray, w::AbstractVector, p::AbstractVector) =
341+
throw(ArgumentError("`weights` must be an `AbstractWeights` vector from StatsBase.jl"))
342+
336343
"""
337344
cut(x::AbstractArray, ngroups::Integer;
338345
labels::Union{AbstractVector{<:AbstractString},Function},
339346
sigdigits::Integer=3,
340-
allowempty::Bool=false)
347+
allowempty::Bool=false,
348+
weights::Union{AbstractWeights, Nothing}=nothing)
341349
342350
Cut a numeric array into `ngroups` quantiles.
343351
@@ -369,19 +377,39 @@ quantiles.
369377
other than the last one are equal, generating empty intervals;
370378
when `true`, duplicate breaks are allowed and the intervals they generate are kept as
371379
unused levels (but duplicate labels are not allowed).
380+
* `weights::Union{AbstractWeights, Nothing}=nothing`: observation weights used to compute weighted quantiles (requires StatsBase.jl to be loaded); when `nothing`, unweighted quantiles are used.
372381
"""
373382
function cut(x::AbstractArray, ngroups::Integer;
374383
labels::Union{AbstractVector{<:SupportedTypes},Function,Nothing}=nothing,
375384
sigdigits::Integer=3,
376-
allowempty::Bool=false)
385+
allowempty::Bool=false,
386+
weights::Union{AbstractVector, Nothing}=nothing)
377387
ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)"))
378-
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
388+
if weights === nothing
389+
sorted_x = eltype(x) >: Missing ? sort!(collect(skipmissing(x))) : sort(x)
390+
min_x, max_x = first(sorted_x), last(sorted_x)
391+
if (min_x isa Number && isnan(min_x)) ||
392+
(max_x isa Number && isnan(max_x))
393+
throw(ArgumentError("NaN values are not allowed in input vector"))
394+
end
395+
else
396+
if eltype(x) >: Missing
397+
nm_inds = findall(!ismissing, x)
398+
nm_x = view(x, nm_inds)
399+
# TODO: use a view once this is supported (JuliaStats/StatsBase.jl#723)
400+
nm_weights = weights[nm_inds]
401+
else
402+
nm_x = x
403+
nm_weights = weights
404+
end
405+
sorted_x = sort(nm_x)
406+
end
379407
min_x, max_x = first(sorted_x), last(sorted_x)
380408
if (min_x isa Number && isnan(min_x)) ||
381409
(max_x isa Number && isnan(max_x))
382410
throw(ArgumentError("NaN values are not allowed in input vector"))
383411
end
384-
qs = quantile!(sorted_x, (1:(ngroups-1))/ngroups, sorted=true)
412+
qs = _quantile!(nm_x, nm_weights, (1:(ngroups-1))/ngroups)
385413
breaks = [min_x; find_breaks(sorted_x, qs); max_x]
386414
if !allowempty && !allunique(@view breaks[1:end-1])
387415
throw(ArgumentError("cannot compute $ngroups quantiles due to " *

test/15_extras.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,4 +423,26 @@ end
423423

424424
end
425425

426+
@testset "cut with weighted quantiles" begin
427+
@test_throws ArgumentError cut(1:3, 3, weights=1:3)
428+
429+
x = collect(Float64, 1:100)
430+
w = fweights(repeat(1:10, inner=10))
431+
y = cut(x, 10, weights=w)
432+
@test levelcode.(y) == levelcode.(cut(x, quantile(x, w, (0:10)./10)))
433+
@test levels(y) == ["[1, 29)", "[29, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
434+
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]
435+
436+
mx = allowmissing(x)
437+
mx[2] = mx[10] = missing
438+
nm_inds = .!ismissing.(mx)
439+
y = cut(mx, 10, weights=w)
440+
@test levelcode.(y) ≅ levelcode.(cut(mx, quantile(x[nm_inds], w[nm_inds], (0:10)./10)))
441+
@test levels(y) == ["[1, 30)", "[30, 43)", "[43, 53)", "[53, 62)", "[62, 70)",
442+
"[70, 77)", "[77, 83)", "[83, 89)", "[89, 95)", "[95, 100]"]
443+
444+
x[5] = NaN
445+
@test_throws ArgumentError cut(x, 3, weights=w)
446+
end
447+
426448
end

0 commit comments

Comments
 (0)