Skip to content

Commit be3bd25

Browse files
committed
Add __index__ to usm_ndarray
1 parent 502daf3 commit be3bd25

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,12 @@ cdef class usm_ndarray:
521521
"only size-1 arrays can be converted to Python scalars"
522522
)
523523

524+
def __index__(self):
525+
if np.issubdtype(self.dtype, np.integer):
526+
return int(self)
527+
528+
raise IndexError("only integer or boolean arrays are valid indices")
529+
524530
def to_device(self, target_device):
525531
"""
526532
Transfer array to target device

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,21 @@ def test_properties():
117117

118118
@pytest.mark.parametrize("func", [bool, float, int])
119119
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
120-
def test_copy_scalar_with_func(func, shape):
121-
X = dpt.usm_ndarray(shape)
122-
Y = np.arange(1, X.size + 1, dtype=X.dtype)
123-
X.usm_data.copy_from_host(Y.view("|u1"))
120+
@pytest.mark.parametrize("dtype", ["|b1", "|f8", "|i8"])
121+
def test_copy_scalar_with_func(func, shape, dtype):
122+
X = dpt.usm_ndarray(shape, dtype=dtype)
123+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
124+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
124125
assert func(X) == func(Y)
125126

126127

127128
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__"])
128129
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
129-
def test_copy_scalar_with_method(method, shape):
130-
X = dpt.usm_ndarray(shape)
131-
Y = np.arange(1, X.size + 1, dtype=X.dtype)
132-
X.usm_data.copy_from_host(Y.view("|u1"))
130+
@pytest.mark.parametrize("dtype", ["|b1", "|f8", "|i8"])
131+
def test_copy_scalar_with_method(method, shape, dtype):
132+
X = dpt.usm_ndarray(shape, dtype=dtype)
133+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
134+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
133135
assert getattr(X, method)() == getattr(Y, method)()
134136

135137

@@ -141,6 +143,34 @@ def test_copy_scalar_invalid_shape(func, shape):
141143
func(X)
142144

143145

146+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
147+
@pytest.mark.parametrize("index_dtype", ["|i8"])
148+
def test_usm_ndarray_as_index(shape, index_dtype):
149+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
150+
Xnp = np.arange(1, X.size + 1, dtype=index_dtype).reshape(shape)
151+
X.usm_data.copy_from_host(Xnp.reshape(-1).view("|u1"))
152+
Y = np.arange(X.size + 1)
153+
assert Y[X] == Y[1]
154+
155+
156+
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
157+
@pytest.mark.parametrize("index_dtype", ["|i8"])
158+
def test_usm_ndarray_as_index_invalid_shape(shape, index_dtype):
159+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
160+
Y = np.arange(X.size + 1)
161+
with pytest.raises(IndexError):
162+
Y[X]
163+
164+
165+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
166+
@pytest.mark.parametrize("index_dtype", ["|f8"])
167+
def test_usm_ndarray_as_index_invalid_dtype(shape, index_dtype):
168+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
169+
Y = np.arange(X.size + 1)
170+
with pytest.raises(IndexError):
171+
Y[X]
172+
173+
144174
@pytest.mark.parametrize(
145175
"ind",
146176
[

0 commit comments

Comments
 (0)