Skip to content

Commit

Permalink
Fix bug where newaxis with full slices doesn't add new axes (#559)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Aug 22, 2024
1 parent 19edd81 commit 3a5e1db
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
66 changes: 34 additions & 32 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,43 +466,45 @@ def merged_chunk_len_for_indexer(ia, c):
return (c // ia.step) * ia.step

shape = idx.newshape(x.shape)

if shape == x.shape:
# no op case
return x
dtype = x.dtype
chunks = tuple(
chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)
# no op case (except possibly newaxis applied below)
out = x
else:
dtype = x.dtype
chunks = tuple(
chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)

target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem
# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem

out = map_direct(
_read_index_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
)
out = map_direct(
_read_index_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
)

# merge chunks for any dims with step > 1 so they are
# the same size as the input (or slightly smaller due to rounding)
merged_chunks = tuple(
merged_chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)
if chunks != merged_chunks:
out = merge_chunks(out, merged_chunks)
# merge chunks for any dims with step > 1 so they are
# the same size as the input (or slightly smaller due to rounding)
merged_chunks = tuple(
merged_chunk_len_for_indexer(ia, c)
for ia, c in zip(idx.args, x.chunksize)
if not isinstance(ia, ndindex.Integer)
)
if chunks != merged_chunks:
out = merge_chunks(out, merged_chunks)

for axis in where_newaxis:
from cubed.array_api.manipulation_functions import expand_dims
Expand Down
5 changes: 5 additions & 0 deletions cubed/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def spec(tmp_path):
[6, 7, 2, 9, 10],
([6, 7, 2, 9, 10], xp.newaxis),
(xp.newaxis, [6, 7, 2, 9, 10]),
(slice(None), xp.newaxis),
(xp.newaxis, slice(None)),
],
)
def test_int_array_index_1d(spec, ind):
Expand All @@ -36,6 +38,9 @@ def test_int_array_index_1d(spec, ind):
(xp.newaxis, slice(None), [2, 1]),
(slice(None), xp.newaxis, [2, 1]),
(slice(None), [2, 1], xp.newaxis),
(xp.newaxis, slice(None), slice(None)),
(slice(None), xp.newaxis, slice(None)),
(slice(None), slice(None), xp.newaxis),
],
)
def test_int_array_index_2d(spec, ind):
Expand Down

0 comments on commit 3a5e1db

Please sign in to comment.