Skip to content

Fix allocations by dropping CategoricalPool type parameter #418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ function CategoricalArray{T, N, R}(::UndefInitializer, dims::NTuple{N,Int};
U = leveltype(nonmissingtype(T))
S = T >: Missing ? Union{U, Missing} : U
check_supported_eltype(S, T)
V = CategoricalValue{U, R}
levs = levels === nothing ? U[] : collect(U, levels)
CategoricalArray{S, N}(zeros(R, dims), CategoricalPool{U, R, V}(levs, ordered))
CategoricalArray{S, N}(zeros(R, dims), CategoricalPool{U, R}(levs, ordered))
end

CategoricalArray{T, N}(::UndefInitializer, dims::NTuple{N,Int};
Expand Down
14 changes: 6 additions & 8 deletions src/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ const catpool_seed = UInt === UInt32 ? 0xe3cf1386 : 0x356f2c715023f1a5

hashlevels(levs::AbstractVector) = foldl((h, x) -> hash(x, h), levs, init=catpool_seed)

CategoricalPool{T, R, V}(ordered::Bool=false) where {T, R, V} =
CategoricalPool{T, R, V}(T[], ordered)
CategoricalPool{T, R}(ordered::Bool=false) where {T, R} =
CategoricalPool{T, R}(T[], ordered)
CategoricalPool{T}(ordered::Bool=false) where {T} =
CategoricalPool{T, DefaultRefType}(T[], ordered)

CategoricalPool{T, R}(levels::AbstractVector, ordered::Bool=false) where {T, R} =
CategoricalPool{T, R, CategoricalValue{T, R}}(convert(Vector{T}, levels), ordered)
CategoricalPool{T, R}(convert(Vector{T}, levels), ordered)
CategoricalPool(levels::AbstractVector{T}, ordered::Bool=false) where {T} =
CategoricalPool{T, DefaultRefType}(convert(Vector{T}, levels), ordered)

CategoricalPool(invindex::Dict{T, R}, ordered::Bool=false) where {T, R <: Integer} =
CategoricalPool{T, R, CategoricalValue{T, R}}(invindex, ordered)
CategoricalPool{T, R}(invindex, ordered)

Base.convert(::Type{T}, pool::T) where {T <: CategoricalPool} = pool

Expand All @@ -29,12 +27,12 @@ function Base.convert(::Type{CategoricalPool{T, R}}, pool::CategoricalPool) wher

levelsT = convert(Vector{T}, pool.levels)
invindexT = convert(Dict{T, R}, pool.invindex)
return CategoricalPool{T, R, CategoricalValue{T, R}}(levelsT, invindexT, pool.ordered)
return CategoricalPool{T, R}(levelsT, invindexT, pool.ordered)
end

Base.copy(pool::CategoricalPool{T, R, V}) where {T, R, V} =
CategoricalPool{T, R, V}(copy(pool.levels), copy(pool.invindex),
pool.ordered, pool.hash)
Base.copy(pool::CategoricalPool{T, R}) where {T, R} =
CategoricalPool{T, R}(copy(pool.levels), copy(pool.invindex),
pool.ordered, pool.hash)

function Base.show(io::IO, pool::CategoricalPool{T, R}) where {T, R}
@static if VERSION >= v"1.6.0"
Expand Down
39 changes: 16 additions & 23 deletions src/typedefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,27 @@ const SupportedTypes = Union{AbstractString, AbstractChar, Number}
# Type params:
# * `T` type of categorized values
# * `R` integer type for referencing category levels
# * `V` categorical value type
mutable struct CategoricalPool{T <: SupportedTypes, R <: Integer, V}
mutable struct CategoricalPool{T <: SupportedTypes, R <: Integer}
levels::Vector{T} # category levels ordered by their reference codes
invindex::Dict{T, R} # map from category levels to their reference codes
ordered::Bool # whether levels can be compared using <
hash::Union{UInt, Nothing} # hash of levels
subsetof::Ptr{Nothing} # last seen strict superset pool
equalto::Ptr{Nothing} # last seen equal pool

function CategoricalPool{T, R, V}(levels::Vector{T},
ordered::Bool) where {T, R, V}
function CategoricalPool{T, R}(levels::Vector{T},
ordered::Bool) where {T, R}
if length(levels) > typemax(R)
throw(LevelsException{T, R}(levels[Int(typemax(R))+1:end]))
end
invindex = Dict{T, R}(v => i for (i, v) in enumerate(levels))
if length(invindex) != length(levels)
throw(ArgumentError("Duplicate entries are not allowed in levels"))
end
CategoricalPool{T, R, V}(levels, invindex, ordered)
CategoricalPool{T, R}(levels, invindex, ordered)
end
function CategoricalPool{T, R, V}(invindex::Dict{T, R},
ordered::Bool) where {T, R, V}
function CategoricalPool{T, R}(invindex::Dict{T, R},
ordered::Bool) where {T, R}
levels = Vector{T}(undef, length(invindex))
# If invindex contains non-consecutive values, a BoundsError will be thrown
try
Expand All @@ -40,18 +39,12 @@ mutable struct CategoricalPool{T <: SupportedTypes, R <: Integer, V}
if length(invindex) > typemax(R)
throw(LevelsException{T, R}(levels[typemax(R)+1:end]))
end
CategoricalPool{T, R, V}(levels, invindex, ordered)
CategoricalPool{T, R}(levels, invindex, ordered)
end
function CategoricalPool{T, R, V}(levels::Vector{T},
invindex::Dict{T, R},
ordered::Bool,
hash::Union{UInt, Nothing}=nothing) where {T, R, V}
if !(V <: CategoricalValue)
throw(ArgumentError("Type $V is not a categorical value type"))
end
if V !== CategoricalValue{T, R}
throw(ArgumentError("V must be CategoricalValue{T, R}"))
end
function CategoricalPool{T, R}(levels::Vector{T},
invindex::Dict{T, R},
ordered::Bool,
hash::Union{UInt, Nothing}=nothing) where {T, R}
pool = new(levels, invindex, ordered, hash, C_NULL, C_NULL)
return pool
end
Expand All @@ -77,7 +70,7 @@ the order of the pool's [`levels`](@ref DataAPI.levels) is used rather than the
ordering of values of type `T`.
"""
struct CategoricalValue{T <: SupportedTypes, R <: Integer}
pool::CategoricalPool{T, R, CategoricalValue{T, R}}
pool::CategoricalPool{T, R}
ref::R
end

Expand All @@ -98,14 +91,14 @@ const AbstractCategoricalMatrix{T, R, V, C, U} = AbstractCategoricalArray{T, 2,

mutable struct CategoricalArray{T, N, R <: Integer, V, C, U} <: AbstractCategoricalArray{T, N, R, V, C, U}
refs::Array{R, N}
pool::CategoricalPool{V, R, C}
pool::CategoricalPool{V, R}

function CategoricalArray{T, N}(refs::Array{R, N},
pool::CategoricalPool{V, R, C}) where
{T, N, R <: Integer, V, C}
pool::CategoricalPool{V, R}) where
{T, N, R <: Integer, V}
T === V || T == Union{V, Missing} || throw(ArgumentError("T ($T) must be equal to $V or Union{$V, Missing}"))
U = T >: Missing ? Missing : Union{}
new{T, N, R, V, C, U}(refs, pool)
new{T, N, R, V, CategoricalValue{V, R}, U}(refs, pool)
end
end
const CategoricalVector{T, R <: Integer, V, C, U} = CategoricalArray{T, 1, R, V, C, U}
Expand Down
22 changes: 5 additions & 17 deletions test/04_constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,10 @@ using CategoricalArrays: DefaultRefType

@testset "Type parameter constraints" begin
# cannot use categorical value as level type
@test_throws TypeError CategoricalPool{CategoricalValue{Int,UInt8}, UInt8, CategoricalValue{CategoricalValue{Int,UInt8},UInt8}}(
@test_throws TypeError CategoricalPool{CategoricalValue{Int,UInt8}, UInt8}(
Dict{CategoricalValue{Int,UInt8}, UInt8}(), false)
@test_throws TypeError CategoricalPool{CategoricalValue{Int,UInt8}, UInt8, CategoricalValue{CategoricalValue{Int,UInt8},UInt8}}(
@test_throws TypeError CategoricalPool{CategoricalValue{Int,UInt8}, UInt8}(
CategoricalValue{Int,UInt8}[], false)
# cannot use non-categorical value as categorical value type
@test_throws ArgumentError CategoricalPool{Int, UInt8, Int}(Int[], false)
@test_throws ArgumentError CategoricalPool{Int, UInt8, Int}(Dict{Int, UInt8}(), false)
# level type of the pool and categorical value must match
@test_throws ArgumentError CategoricalPool{Int, UInt8, CategoricalValue{String, UInt8}}(Int[], false)
@test_throws ArgumentError CategoricalPool{Int, UInt8, CategoricalValue{String, UInt8}}(Dict{Int, UInt8}(), false)
# reference type of the pool and categorical value must match
@test_throws ArgumentError CategoricalPool{Int, UInt8, CategoricalValue{Int, UInt16}}(Int[], false)
@test_throws ArgumentError CategoricalPool{Int, UInt8, CategoricalValue{Int, UInt16}}(Dict{Int, UInt8}(), false)
# correct types combination
@test CategoricalPool{Int, UInt8, CategoricalValue{Int, UInt8}}(Int[], false) isa CategoricalPool
@test CategoricalPool{Int, UInt8, CategoricalValue{Int, UInt8}}(Dict{Int, UInt8}(), false) isa CategoricalPool
end

@testset "empty CategoricalPool{String}" begin
Expand All @@ -38,7 +26,7 @@ end
@testset "empty CategoricalPool{Int}" begin
pool = CategoricalPool{Int, UInt8}()

@test isa(pool, CategoricalPool{Int, UInt8, CategoricalValue{Int, UInt8}})
@test isa(pool, CategoricalPool{Int, UInt8})

@test isa(pool.levels, Vector{Int})
@test length(pool.levels) == 0
Expand All @@ -50,7 +38,7 @@ end
@testset "CategoricalPool{String, DefaultRefType}(a b c)" begin
pool = CategoricalPool(["a", "b", "c"])

@test isa(pool, CategoricalPool{String, UInt32, CategoricalValue{String, UInt32}})
@test isa(pool, CategoricalPool{String, UInt32})

@test isa(pool.levels, Vector{String})
@test pool.levels == ["a", "b", "c"]
Expand Down Expand Up @@ -156,7 +144,7 @@ end
@testset "CategoricalPool{Float64, UInt8}()" begin
pool = CategoricalPool{Float64, UInt8}([1.0, 2.0, 3.0])

@test isa(pool, CategoricalPool{Float64, UInt8, CategoricalValue{Float64, UInt8}})
@test isa(pool, CategoricalPool{Float64, UInt8})
@test CategoricalValue(pool, 1) isa CategoricalValue{Float64, UInt8}
end

Expand Down
Loading