Skip to content

Commit 4a0dc78

Browse files
authored
fix linear indexing of splitdimsview (#40)
1 parent 9deb383 commit 4a0dc78

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/splitdims.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ end
7373
end
7474

7575
@generated function slice_inds(i::CartesianIndex, ::Val{dims}, ::Val{n}) where {dims, n}
76+
@assert length(i) == length(dims)
7677
out = []
7778
for j in 1:n
7879
k = findfirst(==(j), dims)
@@ -96,7 +97,7 @@ Base.parent(a::SplitDimsArray) = a.parent
9697
axes(a::SplitDimsArray{T, N, Dims}) where {T, N, Dims} = getindices(axes(parent(a)), Dims)
9798
size(a::SplitDimsArray{T, N, Dims}) where {T, N, Dims} = getindices(size(parent(a)), Dims)
9899
Base.IndexStyle(::SplitDimsArray) = Base.IndexCartesian()
99-
@propagate_inbounds function Base.getindex(a::SplitDimsArray{T, N, Dims}, i::Int...) where {T, N, Dims}
100+
@propagate_inbounds function Base.getindex(a::SplitDimsArray{T, N, Dims}, i::Vararg{Int, N}) where {T, N, Dims}
100101
return view(parent(a), slice_inds(CartesianIndex(i), Val(Dims), Val(ndims(parent(a))))...)
101102
end
102103

test/splitdims.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,21 @@ end
3434
@test splitdimsview([1 2; 3 4], (1, 2)) == tmp
3535
tmp[1,2][] = 3; tmp[2,1][] = 2
3636
@test splitdimsview([1 2; 3 4], (2, 1)) == tmp
37+
# indexing
38+
@test splitdimsview([1 2; 3 4], (1, 2))[2, 1] == fill(3)
39+
@test splitdimsview([1 2; 3 4], (1, 2))[2] == fill(3)
40+
@test splitdimsview([1 2; 3 4], (1, 2))[CartesianIndex(2)] == fill(3)
3741

3842
# Vector
3943
@test splitdimsview([1,2,3]) == [fill(1, ()), fill(2, ()), fill(3, ())]
4044
@test splitdimsview([1,2,3], (1,)) == [fill(1, ()), fill(2, ()), fill(3, ())]
4145
@test splitdimsview([1,2,3], ()) == fill([1,2,3], ())
46+
# indexing
47+
@test splitdimsview([1,2,3], (1,))[2] == fill(2)
48+
@test splitdimsview([1,2,3], (1,))[CartesianIndex(2)] == fill(2)
4249

4350
# Array{0}
4451
@test splitdimsview(fill(1, ())) == fill(fill(1, ()), ())
52+
# indexing
53+
@test splitdimsview(fill(1, ()))[] == fill(1)
4554
end

0 commit comments

Comments
 (0)