Skip to content

Commit 55c5990

Browse files
committed
Implement support of tuple key in __getitem__ and __setitem__
1 parent d41ed51 commit 55c5990

File tree

2 files changed

+81
-5
lines changed

2 files changed

+81
-5
lines changed

dpnp/dpnp_array.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@
2929

3030
import dpnp
3131

32+
33+
def _get_unwrapped_index_key(key):
34+
"""
35+
Return a key where each nested instance of DPNP array is unwrapped into USM ndarray
36+
for futher processing in DPCTL advanced indexing functions.
37+
38+
"""
39+
40+
if isinstance(key, tuple):
41+
if any(isinstance(x, dpnp_array) for x in key):
42+
# create a new tuple from the input key with unwrapped DPNP arrays
43+
return tuple(x.get_array() if isinstance(x, dpnp_array) else x for x in key)
44+
elif isinstance(key, dpnp_array):
45+
return key.get_array()
46+
return key
47+
48+
3249
class dpnp_array:
3350
"""
3451
Multi-dimensional array object.
@@ -176,8 +193,7 @@ def __ge__(self, other):
176193
# '__getattribute__',
177194

178195
def __getitem__(self, key):
179-
if isinstance(key, dpnp_array):
180-
key = key.get_array()
196+
key = _get_unwrapped_index_key(key)
181197

182198
item = self._array_obj.__getitem__(key)
183199
if not isinstance(item, dpt.usm_ndarray):
@@ -337,8 +353,8 @@ def __rxor__(self, other):
337353
# '__setattr__',
338354

339355
def __setitem__(self, key, val):
340-
if isinstance(key, dpnp_array):
341-
key = key.get_array()
356+
key = _get_unwrapped_index_key(key)
357+
342358
if isinstance(val, dpnp_array):
343359
val = val.get_array()
344360

tests/test_indexing.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,70 @@
66

77
import numpy
88
from numpy.testing import (
9-
assert_array_equal
9+
assert_,
10+
assert_array_equal,
11+
assert_equal
1012
)
1113

1214

15+
class TestIndexing:
16+
def test_ellipsis_index(self):
17+
a = dpnp.array([[1, 2, 3],
18+
[4, 5, 6],
19+
[7, 8, 9]])
20+
assert_(a[...] is not a)
21+
assert_equal(a[...], a)
22+
23+
# test that slicing with ellipsis doesn't skip an arbitrary number of dimensions
24+
assert_equal(a[0, ...], a[0])
25+
assert_equal(a[0, ...], a[0,:])
26+
assert_equal(a[..., 0], a[:, 0])
27+
28+
# test that slicing with ellipsis always results in an array
29+
assert_equal(a[0, ..., 1], dpnp.array(2))
30+
31+
# assignment with `(Ellipsis,)` on 0-d arrays
32+
b = dpnp.array(1)
33+
b[(Ellipsis,)] = 2
34+
assert_equal(b, 2)
35+
36+
def test_boolean_indexing_list(self):
37+
a = dpnp.array([1, 2, 3])
38+
b = dpnp.array([True, False, True])
39+
40+
assert_equal(a[b], [1, 3])
41+
assert_equal(a[None, b], [[1, 3]])
42+
43+
def test_indexing_array_weird_strides(self):
44+
np_x = numpy.ones(10)
45+
dp_x = dpnp.ones(10)
46+
47+
np_ind = numpy.arange(10)[:, None, None, None]
48+
np_ind = numpy.broadcast_to(np_ind, (10, 55, 4, 4))
49+
50+
dp_ind = dpnp.arange(10)[:, None, None, None]
51+
dp_ind = dpnp.broadcast_to(dp_ind, (10, 55, 4, 4))
52+
53+
# single advanced index case
54+
assert_array_equal(dp_x[dp_ind], np_x[np_ind])
55+
56+
np_x2 = numpy.ones((10, 2))
57+
dp_x2 = dpnp.ones((10, 2))
58+
59+
np_zind = numpy.zeros(4, dtype=numpy.intp)
60+
dp_zind = dpnp.asarray(np_zind)
61+
62+
# higher dimensional advanced index
63+
assert_array_equal(dp_x2[dp_ind, dp_zind], np_x2[np_ind, np_zind])
64+
65+
def test_indexing_array_negative_strides(self):
66+
arr = dpnp.zeros((4, 4))[::-1, ::-1]
67+
68+
slices = (slice(None), dpnp.array([0, 1, 2, 3]))
69+
arr[slices] = 10
70+
assert_array_equal(arr, 10.)
71+
72+
1373
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
1474
def test_choose():
1575
a = numpy.r_[:4]

0 commit comments

Comments
 (0)