Skip to content

fix and test mixed CartesianIndex #260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/DiskArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module DiskArrays

using LRUCache: LRUCache, LRU

using Base: tail

# Use the README as the module docs
@doc let
path = joinpath(dirname(@__DIR__), "README.md")
Expand Down
41 changes: 33 additions & 8 deletions src/diskindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ DiskIndex(a, i::Tuple{<:AbstractVector{<:Integer}}, batchstrategy) =
function _resolve_indices(chunks, i, indices_pre::DiskIndex, strategy::BatchStrategy)
inow = first(i)
indices_new, chunksrem = process_index(inow, chunks, strategy)
_resolve_indices(chunksrem, Base.tail(i), merge_index(indices_pre, indices_new), strategy)
_resolve_indices(chunksrem, tail(i), merge_index(indices_pre, indices_new), strategy)
end
# Splat out CartesianIndex as regular indices
function _resolve_indices(
chunks, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
)
_resolve_indices(chunks, (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
end
_resolve_indices(::Tuple{}, ::Tuple{}, indices::DiskIndex, strategy::BatchStrategy) = indices
# No dimension left in array, only singular indices allowed
Expand All @@ -61,17 +67,25 @@ function _resolve_indices(::Tuple{}, i, indices_pre::DiskIndex, strategy::BatchS
(length(inow) == 1 && only(inow) == 1) || throw(ArgumentError("Trailing indices must be 1"))
indices_new = DiskIndex(size(inow), (), size(inow), (), ())
indices = merge_index(indices_pre, indices_new)
_resolve_indices((), Base.tail(i), indices, strategy)
_resolve_indices((), tail(i), indices, strategy)
end
# Splat out CartesianIndex as regular trailing indices
function _resolve_indices(
::Tuple{}, i::Tuple{<:CartesianIndex}, indices_pre::DiskIndex, strategy::BatchStrategy
)
_resolve_indices((), (Tuple(i[1])..., tail(i)...), indices_pre, strategy)
end
# Still dimensions left, but no indices available
function _resolve_indices(chunks, ::Tuple{}, indices_pre::DiskIndex, strategy::BatchStrategy)
chunksnow = first(chunks)
arraysize_from_chunksize(chunksnow) == 1 || throw(ArgumentError("Indices can only be omitted for trailing singleton dimensions"))
checktrailing(arraysize_from_chunksize(chunksnow))
indices_new = add_dimension_index(strategy)
indices = merge_index(indices_pre, indices_new)
_resolve_indices(Base.tail(chunks), (), indices, strategy)
_resolve_indices(tail(chunks), (), indices, strategy)
end

checktrailing(i) = i == 1 || throw(ArgumentError("Indices can only be omitted for trailing singleton dimensions"))

add_dimension_index(::NoBatch) = DiskIndex((), (1,), (), (1,), (1:1,))
add_dimension_index(::Union{ChunkRead,SubRanges}) = DiskIndex((), (1,), ([()],), ([(1,)],), ([(1:1,)],))

Expand All @@ -98,18 +112,24 @@ Calculate indices for `i` the first chunk/s in `chunks`
Returns a [`DiskIndex`](@ref), and the remaining chunks.
"""
process_index(i, chunks, ::NoBatch) = process_index(i, chunks)
process_index(inow::Integer, chunks) = DiskIndex((), (1,), (), (1,), (inow:inow,)), Base.tail(chunks)
function process_index(i::CartesianIndex{N}, chunks, ::NoBatch) where {N}
_, chunksrem = splitchunks(i, chunks)
di = DiskIndex((), map(one, i.I), (), (1,), map(i -> i:i, i.I))
return di, chunksrem
end
process_index(inow::Integer, chunks) =
DiskIndex((), (1,), (), (1,), (inow:inow,)), tail(chunks)
function process_index(::Colon, chunks)
s = arraysize_from_chunksize(first(chunks))
DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), Base.tail(chunks)
DiskIndex((s,), (s,), (Colon(),), (Colon(),), (1:s,),), tail(chunks)
end
function process_index(i::AbstractUnitRange{<:Integer}, chunks, ::NoBatch)
DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), Base.tail(chunks)
DiskIndex((length(i),), (length(i),), (Colon(),), (Colon(),), (i,)), tail(chunks)
end
function process_index(i::AbstractArray{<:Integer}, chunks, ::NoBatch)
indmin, indmax = isempty(i) ? (1, 0) : extrema(i)
di = DiskIndex(size(i), ((indmax - indmin + 1),), map(_ -> Colon(), size(i)), ((i .- (indmin - 1)),), (indmin:indmax,))
return di, Base.tail(chunks)
return di, tail(chunks)
end
function process_index(i::AbstractArray{Bool,N}, chunks, ::NoBatch) where {N}
chunksnow, chunksrem = splitchunks(i, chunks)
Expand Down Expand Up @@ -162,7 +182,12 @@ splitchunks(i::CartesianIndex, chunks) = splitchunks(i.I, (), chunks)
splitchunks(_, chunks) = (first(chunks),), Base.tail(chunks)
splitchunks(si, chunksnow, chunksrem) =
splitchunks(Base.tail(si), (chunksnow..., first(chunksrem)), Base.tail(chunksrem))
function splitchunks(si,chunksnow, ::Tuple{})
only(first(si)) == 1 || throw(ArgumentError("Trailing indices must be 1"))
splitchunks(Base.tail(si), chunksnow, ())
end
splitchunks(::Tuple{}, chunksnow, chunksrem) = (chunksnow, chunksrem)
splitchunks(::Tuple{}, chunksnow, chunksrem::Tuple{}) = (chunksnow, chunksrem)

"""
output_aliasing(di::DiskIndex, ndims_dest, ndims_source)
Expand Down
10 changes: 5 additions & 5 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ function getindex_disk(a::AbstractArray, i::Integer)
return only(outputarray)
end
getindex_disk(a::AbstractArray, i...) = getindex_disk!(nothing, a, i...)
getindex_disk(a::AbstractArray, i::ChunkIndex{<:Any,OneBasedChunks}) =
a[eachchunk(a)[i.I]...]
getindex_disk(a::AbstractArray, i::ChunkIndex{<:Any,OffsetChunks}) =
wrapchunk(a[nooffset(i)], eachchunk(a)[i.I])

function getindex_disk!(out::Union{Nothing,AbstractArray}, a::AbstractArray, i...)
# Check if we can write once or need to use multiple batches
Expand Down Expand Up @@ -202,7 +206,7 @@ end
Generate an `Array` to pass to `readblock!`
"""
function create_outputarray(out::AbstractArray, a::AbstractArray, output_size::Tuple)
size(out) == output_size || throw(ArgumentError("Expected output array size of $output_size"))
size(out) == output_size || throw(ArgumentError("Expected output array size of $output_size, got $(size(out))"))
return out
end
create_outputarray(::Nothing, a::AbstractArray, output_size::Tuple) =
Expand Down Expand Up @@ -306,10 +310,6 @@ macro implement_getindex(t)
quote
DiskArrays.isdisk(::Type{<:$t}) = true
Base.getindex(a::$t, i...) = getindex_disk(a, i...)
@inline Base.getindex(a::$t, i::ChunkIndex{<:Any,OneBasedChunks}) =
a[eachchunk(a)[i.I]...]
@inline Base.getindex(a::$t, i::ChunkIndex{<:Any,OffsetChunks}) =
wrapchunk(a[nooffset(i)], eachchunk(a)[i.I])
function DiskArrays.ChunkIndices(a::$t; offset=false)
return ChunkIndices(
map(s -> 1:s, size(eachchunk(a))), offset ? OffsetChunks() : OneBasedChunks()
Expand Down
16 changes: 10 additions & 6 deletions src/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,20 @@ function eachchunk_view(::Chunked, vv)
end
eachchunk_view(::Unchunked, a) = estimate_chunksize(a)

# Implementaion macro
function view_disk(A, I...)
@inline
# Modified from Base.view
J = to_indices(A, I)
@boundscheck checkbounds(A, J...)
J′ = Base.rm_singleton_indices(ntuple(Returns(true), Val(ndims(A))), J...)
SubDiskArray(Base.unsafe_view(A, J′...))
end

# Implementaion macro
macro implement_subarray(t)
t = esc(t)
quote
function Base.view(a::$t, i...)
i2 = _replace_colon.(size(a), i)
return SubDiskArray(SubArray(a, i2))
end
Base.view(a::$t, i::CartesianIndices) = view(a, i.indices...)
@inline Base.view(a::$t, i...) = view_disk(a, i...)
Base.vec(a::$t) = view(a, :)
end
end
16 changes: 13 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ end
@test a[CartesianIndex(1, 2), 3] == 15
@test a[CartesianIndex(1, 2, 3)] == 15
end

@testset "isdisk" begin
a = reshape(1:24, 2, 3, 4)
da = AccessCountDiskArray(a; chunksize=(2, 2, 2))
Expand All @@ -51,16 +52,22 @@ end

function test_getindex(a)
@test a[2, 3, 1] == 10
@test a[CartesianIndex(2, 3), 1] == 10
@test a[2, CartesianIndex(3,), 1] == 10
@test a[CartesianIndex(2, 3, 1)] == 10
@test a[1:2, CartesianIndex(3, 1, 1)] == 9:10
@test a[2, 3] == 10
@test a[CartesianIndex(2, 3)] == 10
@test a[2, 3, 1, 1] == 10
@test a[:, 1] == [1, 2, 3, 4]
@test a[1:2, 1:2, 1, 1] == [1 5; 2 6]
@test a[end:-1:1, 1, 1] == [4, 3, 2, 1]
@test a[2, 3, 1, 1:1] == [10]
@test a[2, 3, 1, [1], [1]] == fill(10, 1, 1)
@test a[:, 3, 1, [1]] == reshape(9:12, 4, 1)
@test a[:, CartesianIndex(3, 1), [1]] == reshape(9:12, 4, 1)
@test a[CartesianIndices((1:2, 1:2)), 1] == [1 5; 2 6]
@test getindex_count(a) == 10
@test getindex_count(a) == 16
# Test bitmask indexing
m = falses(4, 5, 1)
m[2, [1, 2, 3, 5], 1] .= true
Expand All @@ -73,6 +80,7 @@ function test_getindex(a)
@test a[2:4:14] == [2, 6, 10, 14]
# Test that readblock was called exactly onces for every getindex
@test a[2:2:4, 1:2:5] == [2 10 18; 4 12 20]
@test a[2:2:4, 1:2:5] == [2 10 18; 4 12 20]
@test a[[1, 3, 4], [1, 3], 1] == [1 9; 3 11; 4 12]
@testset "allowscalar" begin
DiskArrays.allowscalar(false)
Expand All @@ -85,7 +93,7 @@ function test_getindex(a)
end

function test_setindex(a)
a[1, 1, 1] = 1
a[CartesianIndex(1, 1), 1] = 1
a[1, 2] = 2
a[1, 3, 1, 1] = 3
a[2:2, :] = [1, 2, 3, 4, 5]
Expand Down Expand Up @@ -117,10 +125,12 @@ function test_view(a)
v[1:2, 1] = [1, 2]
v[1:2, 2:3] = [4 4; 4 4]
@test v[1:2, 1] == [1, 2]
@test v[1:2, CartesianIndex(1,)] == [1, 2]
@test v[1:2, CartesianIndex(1, 1)] == [1, 2]
@test v[1:2, 2:3] == [4 4; 4 4]
@test trueparent(a)[2:3, 2] == [1, 2]
@test trueparent(a)[2:3, 3:4] == [4 4; 4 4]
@test getindex_count(a) == 2
@test getindex_count(a) == 4
@test setindex_count(a) == 2

v2 = view(a, 2:3, 2:4, Int[])
Expand Down
Loading