Skip to content

Commit cf9935c

Browse files
densmirnoleksandr-pavlyk
authored andcommitted
Add __index__ to usm_ndarray
1 parent ab701e1 commit cf9935c

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,12 @@ cdef class usm_ndarray:
521521
"only size-1 arrays can be converted to Python scalars"
522522
)
523523

524+
def __index__(self):
525+
if np.issubdtype(self.dtype, np.integer):
526+
return int(self)
527+
528+
raise IndexError("only integer or boolean arrays are valid indices")
529+
524530
def to_device(self, target_device):
525531
"""
526532
Transfer array to target device

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,19 +116,21 @@ def test_properties():
116116

117117
@pytest.mark.parametrize("func", [bool, float, int])
118118
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
119-
def test_copy_scalar_with_func(func, shape):
120-
X = dpt.usm_ndarray(shape)
121-
Y = np.arange(1, X.size + 1, dtype=X.dtype)
122-
X.usm_data.copy_from_host(Y.view("|u1"))
119+
@pytest.mark.parametrize("dtype", ["|b1", "|f8", "|i8"])
120+
def test_copy_scalar_with_func(func, shape, dtype):
121+
X = dpt.usm_ndarray(shape, dtype=dtype)
122+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
123+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
123124
assert func(X) == func(Y)
124125

125126

126127
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__"])
127128
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
128-
def test_copy_scalar_with_method(method, shape):
129-
X = dpt.usm_ndarray(shape)
130-
Y = np.arange(1, X.size + 1, dtype=X.dtype)
131-
X.usm_data.copy_from_host(Y.view("|u1"))
129+
@pytest.mark.parametrize("dtype", ["|b1", "|f8", "|i8"])
130+
def test_copy_scalar_with_method(method, shape, dtype):
131+
X = dpt.usm_ndarray(shape, dtype=dtype)
132+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
133+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
132134
assert getattr(X, method)() == getattr(Y, method)()
133135

134136

@@ -140,6 +142,34 @@ def test_copy_scalar_invalid_shape(func, shape):
140142
func(X)
141143

142144

145+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
146+
@pytest.mark.parametrize("index_dtype", ["|i8"])
147+
def test_usm_ndarray_as_index(shape, index_dtype):
148+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
149+
Xnp = np.arange(1, X.size + 1, dtype=index_dtype).reshape(shape)
150+
X.usm_data.copy_from_host(Xnp.reshape(-1).view("|u1"))
151+
Y = np.arange(X.size + 1)
152+
assert Y[X] == Y[1]
153+
154+
155+
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
156+
@pytest.mark.parametrize("index_dtype", ["|i8"])
157+
def test_usm_ndarray_as_index_invalid_shape(shape, index_dtype):
158+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
159+
Y = np.arange(X.size + 1)
160+
with pytest.raises(IndexError):
161+
Y[X]
162+
163+
164+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
165+
@pytest.mark.parametrize("index_dtype", ["|f8"])
166+
def test_usm_ndarray_as_index_invalid_dtype(shape, index_dtype):
167+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
168+
Y = np.arange(X.size + 1)
169+
with pytest.raises(IndexError):
170+
Y[X]
171+
172+
143173
@pytest.mark.parametrize(
144174
"ind",
145175
[

0 commit comments

Comments
 (0)