Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/acquire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,8 @@ end
_acquire_impl!(pool, T, dims...)
end

# Fallback for nothing pool
@inline _acquire_impl!(::Nothing, ::Type{T}, n::Int) where {T} = Vector{T}(undef, n)
@inline _acquire_impl!(::Nothing, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
@inline _acquire_impl!(::Nothing, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = Array{T, N}(undef, dims)

# Similar-style
@inline _acquire_impl!(pool::AdaptiveArrayPool, x::AbstractArray) = _acquire_impl!(pool, eltype(x), size(x))
@inline _acquire_impl!(::Nothing, x::AbstractArray) = similar(x)

"""
_unsafe_acquire_impl!(pool, Type{T}, dims...) -> Array{T,N}
Expand All @@ -248,14 +242,8 @@ end
return get_nd_array!(tp, dims)
end

# Fallback for nothing pool
@inline _unsafe_acquire_impl!(::Nothing, ::Type{T}, n::Int) where {T} = Vector{T}(undef, n)
@inline _unsafe_acquire_impl!(::Nothing, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
@inline _unsafe_acquire_impl!(::Nothing, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = Array{T, N}(undef, dims)

# Similar-style
@inline _unsafe_acquire_impl!(pool::AdaptiveArrayPool, x::AbstractArray) = _unsafe_acquire_impl!(pool, eltype(x), size(x))
@inline _unsafe_acquire_impl!(::Nothing, x::AbstractArray) = similar(x)

# ==============================================================================
# Acquisition API (User-facing with untracked marking)
Expand Down
33 changes: 12 additions & 21 deletions src/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,10 @@ function checkpoint!(pool::AdaptiveArrayPool)
push!(pool._untracked_flags, false)
depth = pool._current_depth

# Fixed slots - direct field access, no Dict lookup
_checkpoint_typed_pool!(pool.float64, depth)
_checkpoint_typed_pool!(pool.float32, depth)
_checkpoint_typed_pool!(pool.int64, depth)
_checkpoint_typed_pool!(pool.int32, depth)
_checkpoint_typed_pool!(pool.complexf64, depth)
_checkpoint_typed_pool!(pool.bool, depth)
# Fixed slots - zero allocation via @generated iteration
foreach_fixed_slot(pool) do tp
_checkpoint_typed_pool!(tp, depth)
end

# Others - iterate without allocation (values() returns iterator)
for p in values(pool.others)
Expand Down Expand Up @@ -57,13 +54,10 @@ See also: [`checkpoint!`](@ref), [`@with_pool`](@ref)
function rewind!(pool::AdaptiveArrayPool)
cur_depth = pool._current_depth

# Process fixed slots directly (zero allocation)
_rewind_typed_pool!(pool.float64, cur_depth)
_rewind_typed_pool!(pool.float32, cur_depth)
_rewind_typed_pool!(pool.int64, cur_depth)
_rewind_typed_pool!(pool.int32, cur_depth)
_rewind_typed_pool!(pool.complexf64, cur_depth)
_rewind_typed_pool!(pool.bool, cur_depth)
# Fixed slots - zero allocation via @generated iteration
foreach_fixed_slot(pool) do tp
_rewind_typed_pool!(tp, cur_depth)
end

# Process fallback types
for tp in values(pool.others)
Expand Down Expand Up @@ -277,13 +271,10 @@ empty!(pool) # Release all memory
Any SubArrays previously acquired from this pool become invalid after `empty!`.
"""
function Base.empty!(pool::AdaptiveArrayPool)
# Fixed slots
empty!(pool.float64)
empty!(pool.float32)
empty!(pool.int64)
empty!(pool.int32)
empty!(pool.complexf64)
empty!(pool.bool)
# Fixed slots - zero allocation via @generated iteration
foreach_fixed_slot(pool) do tp
empty!(tp)
end

# Others - clear all TypedPools then the IdDict itself
for tp in values(pool.others)
Expand Down
47 changes: 37 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,28 @@ TypedPool{T}() where {T} = TypedPool{T}(
)

# ==============================================================================
# AdaptiveArrayPool
# Fixed Slot Configuration
# ==============================================================================

"""
AdaptiveArrayPool
FIXED_SLOT_FIELDS

Field names for fixed slot TypedPools. Single source of truth for `foreach_fixed_slot`.

When modifying, also update: struct definition, `get_typed_pool!` dispatches, constructor.
Tests verify synchronization automatically.
"""
const FIXED_SLOT_FIELDS = (:float64, :float32, :int64, :int32, :complexf64, :complexf32, :bool)

A high-performance memory pool supporting multiple data types.
# ==============================================================================
# AdaptiveArrayPool
# ==============================================================================

## Features
- **Fixed Slots**: `Float64`, `Float32`, `Int64`, `Int32`, `ComplexF64`, `Bool` have dedicated fields (zero Dict lookup)
- **Fallback**: Other types use `IdDict` (still fast, but with lookup overhead)
- **Zero Allocation**: `checkpoint!/rewind!` use internal stacks, no allocation after warmup
- **Untracked Detection**: `_current_depth` and `_untracked_flags` track acquire calls from inner functions
"""
AdaptiveArrayPool

## Thread Safety
This pool is **NOT thread-safe**. Use one pool per Task via `get_task_local_pool()`.
Multi-type memory pool with fixed slots for common types and IdDict fallback for others.
Zero allocation after warmup. NOT thread-safe - use one pool per Task.
"""
mutable struct AdaptiveArrayPool
# Fixed Slots: common types with zero lookup overhead
Expand All @@ -159,6 +165,7 @@ mutable struct AdaptiveArrayPool
int64::TypedPool{Int64}
int32::TypedPool{Int32}
complexf64::TypedPool{ComplexF64}
complexf32::TypedPool{ComplexF32}
bool::TypedPool{Bool}

# Fallback: rare types
Expand All @@ -176,6 +183,7 @@ function AdaptiveArrayPool()
TypedPool{Int64}(),
TypedPool{Int32}(),
TypedPool{ComplexF64}(),
TypedPool{ComplexF32}(),
TypedPool{Bool}(),
IdDict{DataType, Any}(),
1, # _current_depth: 1 = global scope (sentinel)
Expand All @@ -193,6 +201,7 @@ end
@inline get_typed_pool!(p::AdaptiveArrayPool, ::Type{Int64}) = p.int64
@inline get_typed_pool!(p::AdaptiveArrayPool, ::Type{Int32}) = p.int32
@inline get_typed_pool!(p::AdaptiveArrayPool, ::Type{ComplexF64}) = p.complexf64
@inline get_typed_pool!(p::AdaptiveArrayPool, ::Type{ComplexF32}) = p.complexf32
@inline get_typed_pool!(p::AdaptiveArrayPool, ::Type{Bool}) = p.bool

# Slow Path: rare types via IdDict
Expand All @@ -208,3 +217,21 @@ end
tp
end::TypedPool{T}
end

# ==============================================================================
# Zero-Allocation Iteration
# ==============================================================================

"""
foreach_fixed_slot(f, pool::AdaptiveArrayPool)

Apply `f` to each fixed slot TypedPool. Zero allocation via compile-time unrolling.
"""
@generated function foreach_fixed_slot(f::F, pool::AdaptiveArrayPool) where {F}
exprs = [:(f(getfield(pool, $(QuoteNode(field))))) for field in FIXED_SLOT_FIELDS]
quote
Base.@_inline_meta
$(exprs...)
nothing
end
end
60 changes: 30 additions & 30 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,25 @@ function _check_pointer_overlap(arr::Array, pool::AdaptiveArrayPool)
arr_len = length(arr) * sizeof(eltype(arr))
arr_end = arr_ptr + arr_len

# Check fixed slots
for tp in (pool.float64, pool.float32, pool.int64, pool.int32, pool.complexf64, pool.bool)
check_overlap = function(tp)
for v in tp.vectors
v_ptr = UInt(pointer(v))
v_len = length(v) * sizeof(eltype(v))
v_end = v_ptr + v_len
# Check memory range overlap
if !(arr_end <= v_ptr || v_end <= arr_ptr)
error("Safety Violation: The function returned an Array backed by pool memory. This is unsafe as the memory will be reclaimed. Please return a copy (collect) or a scalar.")
end
end
end

# Check fixed slots
foreach_fixed_slot(pool) do tp
check_overlap(tp)
end

# Check others
for tp in values(pool.others)
for v in tp.vectors
v_ptr = UInt(pointer(v))
v_len = length(v) * sizeof(eltype(v))
v_end = v_ptr + v_len
if !(arr_end <= v_ptr || v_end <= arr_ptr)
error("Safety Violation: The function returned an Array backed by pool memory. This is unsafe as the memory will be reclaimed. Please return a copy (collect) or a scalar.")
end
end
check_overlap(tp)
end
end

Expand Down Expand Up @@ -154,22 +150,14 @@ function pool_stats(pool::AdaptiveArrayPool; io::IO=stdout)
printstyled(io, "AdaptiveArrayPool", bold=true, color=:white)
println(io)

fixed_slots = [
("Float64", pool.float64),
("Float32", pool.float32),
("Int64", pool.int64),
("Int32", pool.int32),
("ComplexF64", pool.complexf64),
("Bool", pool.bool)
]

has_content = false

# Fixed slots
for (name, tp) in fixed_slots
# Fixed slots - use foreach_fixed_slot for consistency
foreach_fixed_slot(pool) do tp
if !isempty(tp.vectors)
has_content = true
pool_stats(tp; io, indent=2, name="$name (fixed)")
T = typeof(tp).parameters[1] # Extract T from TypedPool{T}
pool_stats(tp; io, indent=2, name="$T (fixed)")
end
end

Expand Down Expand Up @@ -221,13 +209,25 @@ end

# Compact one-line show for AdaptiveArrayPool
function Base.show(io::IO, pool::AdaptiveArrayPool)
fixed_slots = (pool.float64, pool.float32, pool.int64, pool.int32, pool.complexf64, pool.bool)
n_types = count(tp -> !isempty(tp.vectors), fixed_slots) + length(pool.others)
total_vectors = sum(length(tp.vectors) for tp in fixed_slots; init=0) +
sum(length(tp.vectors) for tp in values(pool.others); init=0)
total_active = sum(tp.n_active for tp in fixed_slots; init=0) +
sum(tp.n_active for tp in values(pool.others); init=0)
print(io, "AdaptiveArrayPool(types=$n_types, vectors=$total_vectors, active=$total_active)")
n_types = Ref(0)
total_vectors = Ref(0)
total_active = Ref(0)

foreach_fixed_slot(pool) do tp
if !isempty(tp.vectors)
n_types[] += 1
end
total_vectors[] += length(tp.vectors)
total_active[] += tp.n_active
end

n_types[] += length(pool.others)
for tp in values(pool.others)
total_vectors[] += length(tp.vectors)
total_active[] += tp.n_active
end

print(io, "AdaptiveArrayPool(types=$(n_types[]), vectors=$(total_vectors[]), active=$(total_active[]))")
end

# Multi-line show for AdaptiveArrayPool
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ else
include("test_disabled_pooling.jl")
include("test_aliases.jl")
include("test_nway_cache.jl")
include("test_fixed_slots.jl")
end
65 changes: 65 additions & 0 deletions test/test_aliases.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,69 @@ using AdaptiveArrayPools
end
end

@testset "Similar-style _impl! via macro (runtime coverage)" begin
# These tests exercise the _acquire_impl!(pool, x::AbstractArray) and
# _unsafe_acquire_impl!(pool, x::AbstractArray) methods which are only
# called through macro transformation (not public API).

ref_mat = rand(5, 6)
ref_vec = rand(10)
ref_int = rand(Int32, 3, 4)

@testset "acquire!(pool, x) via @with_pool" begin
pool = AdaptiveArrayPool()

result = @with_pool pool begin
# Similar-style acquire - macro transforms to _acquire_impl!(pool, ref_mat)
mat = acquire!(pool, ref_mat)
@test size(mat) == size(ref_mat)
@test eltype(mat) == eltype(ref_mat)
@test mat isa Base.ReshapedArray{Float64, 2}

vec = acquire!(pool, ref_vec)
@test size(vec) == size(ref_vec)
@test vec isa SubArray{Float64, 1}

int_mat = acquire!(pool, ref_int)
@test eltype(int_mat) == Int32
@test size(int_mat) == (3, 4)

sum(mat) + sum(vec) + sum(int_mat)
end
@test result isa Float64
end

@testset "unsafe_acquire!(pool, x) via @with_pool" begin
pool = AdaptiveArrayPool()

result = @with_pool pool begin
# Similar-style unsafe_acquire - macro transforms to _unsafe_acquire_impl!(pool, ref_mat)
mat = unsafe_acquire!(pool, ref_mat)
@test size(mat) == size(ref_mat)
@test mat isa Matrix{Float64}

vec = unsafe_acquire!(pool, ref_vec)
@test size(vec) == size(ref_vec)
@test vec isa Vector{Float64}

sum(mat) + sum(vec)
end
@test result isa Float64
end

@testset "acquire_view!/acquire_array! aliases via @with_pool" begin
pool = AdaptiveArrayPool()

@with_pool pool begin
# acquire_view! is alias for acquire!
v1 = acquire_view!(pool, ref_mat)
@test size(v1) == size(ref_mat)

# acquire_array! is alias for unsafe_acquire!
v2 = acquire_array!(pool, ref_vec)
@test size(v2) == size(ref_vec)
end
end
end

end
9 changes: 7 additions & 2 deletions test/test_basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,15 @@
@test pool.int32.n_active == 1

# ComplexF64 - fixed slot
vc = acquire!(pool, ComplexF64, 5)
@test eltype(vc) == ComplexF64
vc64 = acquire!(pool, ComplexF64, 5)
@test eltype(vc64) == ComplexF64
@test pool.complexf64.n_active == 1

# ComplexF32 - fixed slot
vc32 = acquire!(pool, ComplexF32, 5)
@test eltype(vc32) == ComplexF32
@test pool.complexf32.n_active == 1

# Bool - fixed slot
vb = acquire!(pool, Bool, 5)
@test eltype(vb) == Bool
Expand Down
Loading