Skip to content

Commit eacda8d

Browse files
More extensive orthogonal indexing in get/setitem (#1333)
* More extensive orthogonal indexing in get/setitem Added pass-through to orthogonal indexing for the following cases: * index is iterable of integers * index is iterable of length ndim, with each element being a slice, integer, or list. Maximum one list. * Add test cases for indexing with single integer iterable --------- Co-authored-by: Josh Moore <josh@openmicroscopy.org>
1 parent c77f9cd commit eacda8d

File tree

5 files changed

+237
-18
lines changed

5 files changed

+237
-18
lines changed

docs/release.rst

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@ Release notes
66
# to document your changes. On releases it will be
77
# re-indented so that it does not show up in the notes.
88
9-
.. _unreleased:
9+
.. _unreleased:
1010

11-
Unreleased
12-
----------
11+
Unreleased
12+
----------
1313

1414
..
1515
# .. warning::
1616
# Pre-release! Use :command:`pip install --pre zarr` to evaluate this release.
1717
18+
* Implement more extensive fallback of getitem/setitem for orthogonal indexing.
19+
By :user:`Andreas Albert <AndreasAlbertQC>` :issue:`1029`.
20+
1821
.. _release_2.14.2:
1922

2023
2.14.2

docs/tutorial.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,13 @@ For convenience, the orthogonal indexing functionality is also available via the
634634
Any combination of integer, slice, 1D integer array and/or 1D Boolean array can
635635
be used for orthogonal indexing.
636636

637+
If the index contains at most one iterable, and otherwise contains only slices and integers,
638+
orthogonal indexing is also available directly on the array:
639+
640+
>>> z = zarr.array(np.arange(15).reshape(3, 5))
641+
>>> all(z.oindex[[0, 2], :] == z[[0, 2], :])
642+
True
643+
637644
Indexing fields in structured arrays
638645
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
639646

zarr/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
err_too_many_indices,
2929
is_contiguous_selection,
3030
is_pure_fancy_indexing,
31+
is_pure_orthogonal_indexing,
3132
is_scalar,
3233
pop_fields,
3334
)
@@ -817,6 +818,8 @@ def __getitem__(self, selection):
817818
fields, pure_selection = pop_fields(selection)
818819
if is_pure_fancy_indexing(pure_selection, self.ndim):
819820
result = self.vindex[selection]
821+
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
822+
result = self.get_orthogonal_selection(pure_selection, fields=fields)
820823
else:
821824
result = self.get_basic_selection(pure_selection, fields=fields)
822825
return result
@@ -1387,6 +1390,8 @@ def __setitem__(self, selection, value):
13871390
fields, pure_selection = pop_fields(selection)
13881391
if is_pure_fancy_indexing(pure_selection, self.ndim):
13891392
self.vindex[selection] = value
1393+
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
1394+
self.set_orthogonal_selection(pure_selection, value, fields=fields)
13901395
else:
13911396
self.set_basic_selection(pure_selection, value, fields=fields)
13921397

zarr/indexing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ def is_pure_fancy_indexing(selection, ndim):
101101
)
102102

103103

104+
def is_pure_orthogonal_indexing(selection, ndim):
105+
if not ndim:
106+
return False
107+
108+
# Case 1: Selection is a single iterable of integers
109+
if is_integer_list(selection) or is_integer_array(selection, ndim=1):
110+
return True
111+
112+
# Case two: selection contains either zero or one integer iterables.
113+
# All other selection elements are slices or integers
114+
return (
115+
isinstance(selection, tuple) and len(selection) == ndim and
116+
sum(is_integer_list(elem) or is_integer_array(elem) for elem in selection) <= 1 and
117+
all(
118+
is_integer_list(elem) or is_integer_array(elem)
119+
or isinstance(elem, slice) or isinstance(elem, int) for
120+
elem in selection)
121+
)
122+
123+
104124
def normalize_integer_selection(dim_sel, dim_len):
105125

106126
# normalize type to int

zarr/tests/test_indexing.py

Lines changed: 199 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,6 @@ def test_get_basic_selection_2d():
283283
for selection in bad_selections:
284284
with pytest.raises(IndexError):
285285
z.get_basic_selection(selection)
286-
with pytest.raises(IndexError):
287-
z[selection]
288286
# check fallback on fancy indexing
289287
fancy_selection = ([0, 1], [0, 1])
290288
np.testing.assert_array_equal(z[fancy_selection], [0, 11])
@@ -317,14 +315,179 @@ def test_fancy_indexing_fallback_on_get_setitem():
317315
)
318316

319317

320-
def test_fancy_indexing_doesnt_mix_with_slicing():
321-
z = zarr.zeros((20, 20))
322-
with pytest.raises(IndexError):
323-
z[[1, 2, 3], :] = 2
324-
with pytest.raises(IndexError):
325-
np.testing.assert_array_equal(
326-
z[[1, 2, 3], :], 0
318+
@pytest.mark.parametrize("index,expected_result",
319+
[
320+
# Single iterable of integers
321+
(
322+
[0, 1],
323+
[[0, 1, 2],
324+
[3, 4, 5]]
325+
),
326+
# List first, then slice
327+
(
328+
([0, 1], slice(None)),
329+
[[0, 1, 2],
330+
[3, 4, 5]]
331+
),
332+
# List first, then slice
333+
(
334+
([0, 1], slice(1, None)),
335+
[[1, 2],
336+
[4, 5]]
337+
),
338+
# Slice first, then list
339+
(
340+
(slice(0, 2), [0, 2]),
341+
[[0, 2],
342+
[3, 5]]
343+
),
344+
# Slices only
345+
(
346+
(slice(0, 2), slice(0, 2)),
347+
[[0, 1],
348+
[3, 4]]
349+
),
350+
# List with repeated index
351+
(
352+
([1, 0, 1], slice(1, None)),
353+
[[4, 5],
354+
[1, 2],
355+
[4, 5]]
356+
),
357+
# 1D indexing
358+
(
359+
([1, 0, 1]),
360+
[
361+
[3, 4, 5],
362+
[0, 1, 2],
363+
[3, 4, 5]
364+
]
365+
)
366+
367+
])
368+
def test_orthogonal_indexing_fallback_on_getitem_2d(index, expected_result):
369+
"""
370+
Tests the orthogonal indexing fallback on __getitem__ for a 2D matrix.
371+
372+
In addition to checking expected behavior, all indexing
373+
is also checked against numpy.
374+
"""
375+
# [0, 1, 2],
376+
# [3, 4, 5],
377+
# [6, 7, 8]
378+
a = np.arange(9).reshape(3, 3)
379+
z = zarr.array(a)
380+
381+
np.testing.assert_array_equal(z[index], a[index], err_msg="Indexing disagrees with numpy")
382+
np.testing.assert_array_equal(z[index], expected_result)
383+
384+
385+
@pytest.mark.parametrize("index,expected_result",
386+
[
387+
# Single iterable of integers
388+
(
389+
[0, 1],
390+
[[[0, 1, 2],
391+
[3, 4, 5],
392+
[6, 7, 8]],
393+
[[9, 10, 11],
394+
[12, 13, 14],
395+
[15, 16, 17]]]
396+
),
397+
# One slice, two integers
398+
(
399+
(slice(0, 2), 1, 1),
400+
[4, 13]
401+
),
402+
# One integer, two slices
403+
(
404+
(slice(0, 2), 1, slice(0, 2)),
405+
[[3, 4], [12, 13]]
406+
),
407+
# Two slices and a list
408+
(
409+
(slice(0, 2), [1, 2], slice(0, 2)),
410+
[[[3, 4], [6, 7]], [[12, 13], [15, 16]]]
411+
),
412+
])
413+
def test_orthogonal_indexing_fallback_on_getitem_3d(index, expected_result):
414+
"""
415+
Tests the orthogonal indexing fallback on __getitem__ for a 3D matrix.
416+
417+
In addition to checking expected behavior, all indexing
418+
is also checked against numpy.
419+
"""
420+
# [[[ 0, 1, 2],
421+
# [ 3, 4, 5],
422+
# [ 6, 7, 8]],
423+
424+
# [[ 9, 10, 11],
425+
# [12, 13, 14],
426+
# [15, 16, 17]],
427+
428+
# [[18, 19, 20],
429+
# [21, 22, 23],
430+
# [24, 25, 26]]]
431+
a = np.arange(27).reshape(3, 3, 3)
432+
z = zarr.array(a)
433+
434+
np.testing.assert_array_equal(z[index], a[index], err_msg="Indexing disagrees with numpy")
435+
np.testing.assert_array_equal(z[index], expected_result)
436+
437+
438+
@pytest.mark.parametrize(
439+
"index,expected_result",
440+
[
441+
# Single iterable of integers
442+
(
443+
[0, 1],
444+
[
445+
[1, 1, 1],
446+
[1, 1, 1],
447+
[0, 0, 0]
448+
]
449+
),
450+
# List and slice combined
451+
(
452+
([0, 1], slice(1, 3)),
453+
[[0, 1, 1],
454+
[0, 1, 1],
455+
[0, 0, 0]]
456+
),
457+
# Index repetition is ignored on setitem
458+
(
459+
([0, 1, 1, 1, 1, 1, 1], slice(1, 3)),
460+
[[0, 1, 1],
461+
[0, 1, 1],
462+
[0, 0, 0]]
463+
),
464+
# Slice with step
465+
(
466+
([0, 2], slice(None, None, 2)),
467+
[[1, 0, 1],
468+
[0, 0, 0],
469+
[1, 0, 1]]
327470
)
471+
]
472+
)
473+
def test_orthogonal_indexing_fallback_on_setitem_2d(index, expected_result):
474+
"""
475+
Tests the orthogonal indexing fallback on __setitem__ for a 3D matrix.
476+
477+
In addition to checking expected behavior, all indexing
478+
is also checked against numpy.
479+
"""
480+
# Slice + fancy index
481+
a = np.zeros((3, 3))
482+
z = zarr.array(a)
483+
z[index] = 1
484+
a[index] = 1
485+
np.testing.assert_array_equal(
486+
z, expected_result
487+
)
488+
np.testing.assert_array_equal(
489+
z, a, err_msg="Indexing disagrees with numpy"
490+
)
328491

329492

330493
def test_fancy_indexing_doesnt_mix_with_implicit_slicing():
@@ -335,12 +498,6 @@ def test_fancy_indexing_doesnt_mix_with_implicit_slicing():
335498
np.testing.assert_array_equal(
336499
z2[[1, 2, 3], [1, 2, 3]], 0
337500
)
338-
with pytest.raises(IndexError):
339-
z2[[1, 2, 3]] = 2
340-
with pytest.raises(IndexError):
341-
np.testing.assert_array_equal(
342-
z2[[1, 2, 3]], 0
343-
)
344501
with pytest.raises(IndexError):
345502
z2[..., [1, 2, 3]] = 2
346503
with pytest.raises(IndexError):
@@ -770,6 +927,33 @@ def test_set_orthogonal_selection_3d():
770927
_test_set_orthogonal_selection_3d(v, a, z, ix0, ix1, ix2)
771928

772929

930+
def test_orthogonal_indexing_fallback_on_get_setitem():
931+
z = zarr.zeros((20, 20))
932+
z[[1, 2, 3], [1, 2, 3]] = 1
933+
np.testing.assert_array_equal(
934+
z[:4, :4],
935+
[
936+
[0, 0, 0, 0],
937+
[0, 1, 0, 0],
938+
[0, 0, 1, 0],
939+
[0, 0, 0, 1],
940+
],
941+
)
942+
np.testing.assert_array_equal(
943+
z[[1, 2, 3], [1, 2, 3]], 1
944+
)
945+
# test broadcasting
946+
np.testing.assert_array_equal(
947+
z[1, [1, 2, 3]], [1, 0, 0]
948+
)
949+
# test 1D fancy indexing
950+
z2 = zarr.zeros(5)
951+
z2[[1, 2, 3]] = 1
952+
np.testing.assert_array_equal(
953+
z2, [0, 1, 1, 1, 0]
954+
)
955+
956+
773957
def _test_get_coordinate_selection(a, z, selection):
774958
expect = a[selection]
775959
actual = z.get_coordinate_selection(selection)

0 commit comments

Comments
 (0)