diff --git a/cubed/core/ops.py b/cubed/core/ops.py
index a4850596..e589c4e2 100644
--- a/cubed/core/ops.py
+++ b/cubed/core/ops.py
@@ -466,43 +466,45 @@ def merged_chunk_len_for_indexer(ia, c):
         return (c // ia.step) * ia.step
 
     shape = idx.newshape(x.shape)
+
     if shape == x.shape:
-        # no op case
-        return x
-    dtype = x.dtype
-    chunks = tuple(
-        chunk_len_for_indexer(ia, c)
-        for ia, c in zip(idx.args, x.chunksize)
-        if not isinstance(ia, ndindex.Integer)
-    )
+        # no op case (except possibly newaxis applied below)
+        out = x
+    else:
+        dtype = x.dtype
+        chunks = tuple(
+            chunk_len_for_indexer(ia, c)
+            for ia, c in zip(idx.args, x.chunksize)
+            if not isinstance(ia, ndindex.Integer)
+        )
 
-    target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
+        target_chunks = normalize_chunks(chunks, shape, dtype=dtype)
 
-    # memory allocated by reading one chunk from input array
-    # note that although the output chunk will overlap multiple input chunks, zarr will
-    # read the chunks in series, reusing the buffer
-    extra_projected_mem = x.chunkmem
+        # memory allocated by reading one chunk from input array
+        # note that although the output chunk will overlap multiple input chunks, zarr will
+        # read the chunks in series, reusing the buffer
+        extra_projected_mem = x.chunkmem
 
-    out = map_direct(
-        _read_index_chunk,
-        x,
-        shape=shape,
-        dtype=dtype,
-        chunks=target_chunks,
-        extra_projected_mem=extra_projected_mem,
-        target_chunks=target_chunks,
-        selection=selection,
-    )
+        out = map_direct(
+            _read_index_chunk,
+            x,
+            shape=shape,
+            dtype=dtype,
+            chunks=target_chunks,
+            extra_projected_mem=extra_projected_mem,
+            target_chunks=target_chunks,
+            selection=selection,
+        )
 
-    # merge chunks for any dims with step > 1 so they are
-    # the same size as the input (or slightly smaller due to rounding)
-    merged_chunks = tuple(
-        merged_chunk_len_for_indexer(ia, c)
-        for ia, c in zip(idx.args, x.chunksize)
-        if not isinstance(ia, ndindex.Integer)
-    )
-    if chunks != merged_chunks:
-        out = merge_chunks(out, merged_chunks)
+        # merge chunks for any dims with step > 1 so they are
+        # the same size as the input (or slightly smaller due to rounding)
+        merged_chunks = tuple(
+            merged_chunk_len_for_indexer(ia, c)
+            for ia, c in zip(idx.args, x.chunksize)
+            if not isinstance(ia, ndindex.Integer)
+        )
+        if chunks != merged_chunks:
+            out = merge_chunks(out, merged_chunks)
 
     for axis in where_newaxis:
         from cubed.array_api.manipulation_functions import expand_dims
diff --git a/cubed/tests/test_indexing.py b/cubed/tests/test_indexing.py
index efa8c6cf..2da1622d 100644
--- a/cubed/tests/test_indexing.py
+++ b/cubed/tests/test_indexing.py
@@ -20,6 +20,8 @@ def spec(tmp_path):
         [6, 7, 2, 9, 10],
         ([6, 7, 2, 9, 10], xp.newaxis),
         (xp.newaxis, [6, 7, 2, 9, 10]),
+        (slice(None), xp.newaxis),
+        (xp.newaxis, slice(None)),
     ],
 )
 def test_int_array_index_1d(spec, ind):
@@ -36,6 +38,9 @@ def test_int_array_index_1d(spec, ind):
         (xp.newaxis, slice(None), [2, 1]),
         (slice(None), xp.newaxis, [2, 1]),
         (slice(None), [2, 1], xp.newaxis),
+        (xp.newaxis, slice(None), slice(None)),
+        (slice(None), xp.newaxis, slice(None)),
+        (slice(None), slice(None), xp.newaxis),
     ],
 )
 def test_int_array_index_2d(spec, ind):
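A minimal sketch of the behaviour the new test cases cover, assuming cubed is installed with this change and, as in the test module, cubed.array_api is imported as xp: an index that leaves the shape unchanged is now treated as a no-op on the underlying array, while any newaxis entries are still applied afterwards via expand_dims.

import numpy as np

import cubed.array_api as xp

a = xp.asarray([1, 2, 3, 4, 5], chunks=2)

# the slice leaves the shape unchanged, but the newaxis must still be applied
b = a[:, xp.newaxis]
assert b.shape == (5, 1)
np.testing.assert_array_equal(
    b.compute(), np.array([1, 2, 3, 4, 5])[:, np.newaxis]
)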