Skip to content

Commit 34ce6b7

Browse files
Merge pull request #578 from IntelPython/scalar_copy
2 parents b97efdf + 1f34ee8 commit 34ce6b7

File tree

4 files changed

+142
-7
lines changed

4 files changed

+142
-7
lines changed

dpctl/memory/_sycl_usm_array_interface_utils.pxi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
cdef bint _valid_usm_ptr_and_context(DPCTLSyclUSMRef ptr, SyclContext ctx):
44
usm_type = _Memory.get_pointer_type(ptr, ctx)
5-
return usm_type in (b'shared', b'device', b'host')
5+
return usm_type in (b"shared", b"device", b"host")
66

77

88
cdef DPCTLSyclQueueRef _queue_ref_copy_from_SyclQueue(
@@ -49,7 +49,7 @@ cdef DPCTLSyclQueueRef get_queue_ref_from_ptr_and_syclobj(
4949
elif pycapsule.PyCapsule_IsValid(syclobj, "SyclContextRef"):
5050
ctx = <SyclContext>SyclContext(syclobj)
5151
return _queue_ref_copy_from_USMRef_and_SyclContext(ptr, ctx)
52-
elif hasattr(syclobj, '_get_capsule'):
52+
elif hasattr(syclobj, "_get_capsule"):
5353
cap = syclobj._get_capsule()
5454
if pycapsule.PyCapsule_IsValid(cap, "SyclQueueRef"):
5555
q = SyclQueue(cap)
@@ -166,8 +166,8 @@ cdef class _USMBufferData:
166166
nd = len(ary_shape)
167167
try:
168168
dt = np.dtype(ary_typestr)
169-
if (dt.hasobject or not (np.issubdtype(dt.type, np.integer) or
170-
np.issubdtype(dt.type, np.inexact))):
169+
if (dt.hasobject or not (np.issubdtype(dt.type, np.number) or
170+
dt.type is np.bool_)):
171171
DPCTLQueue_Delete(QRef)
172172
raise TypeError("Only integer types, floating and complex "
173173
"floating types are supported.")

dpctl/tensor/_slicing.pxi

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ cdef object _basic_slice_meta(object ind, tuple shape,
4444
Raises IndexError for invalid index `ind`, and NotImplementedError
4545
if `ind` is an array.
4646
"""
47+
is_integral = lambda x: (
48+
isinstance(x, numbers.Integral) or callable(getattr(x, "__index__", None))
49+
)
4750
if ind is Ellipsis:
4851
return (shape, strides, offset)
4952
elif ind is None:
@@ -58,7 +61,8 @@ cdef object _basic_slice_meta(object ind, tuple shape,
5861
new_strides,
5962
offset + sl_start * strides[0]
6063
)
61-
elif isinstance(ind, numbers.Integral):
64+
elif is_integral(ind):
65+
ind = ind.__index__()
6266
if 0 <= ind < shape[0]:
6367
return (shape[1:], strides[1:], offset + ind * strides[0])
6468
elif -shape[0] <= ind < 0:
@@ -82,7 +86,7 @@ cdef object _basic_slice_meta(object ind, tuple shape,
8286
ellipses_count = ellipses_count + 1
8387
elif isinstance(i, slice):
8488
axes_referenced = axes_referenced + 1
85-
elif isinstance(i, numbers.Integral):
89+
elif is_integral(i):
8690
explicit_index = explicit_index + 1
8791
axes_referenced = axes_referenced + 1
8892
elif isinstance(i, list):
@@ -124,7 +128,8 @@ cdef object _basic_slice_meta(object ind, tuple shape,
124128
new_strides.append(str_i)
125129
new_offset = new_offset + sl_start * strides[k]
126130
k = k_new
127-
elif isinstance(ind_i, numbers.Integral):
131+
elif is_integral(ind_i):
132+
ind_i = ind_i.__index__()
128133
if 0 <= ind_i < shape[k]:
129134
k_new = k + 1
130135
new_offset = new_offset + ind_i * strides[k]

dpctl/tensor/_usmarray.pyx

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,54 @@ cdef class usm_ndarray:
491491
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
492492
return res
493493

494+
def __bool__(self):
495+
if self.size == 1:
496+
mem_view = dpmem.as_usm_memory(self)
497+
return mem_view.copy_to_host().view(self.dtype).__bool__()
498+
499+
if self.size == 0:
500+
raise ValueError(
501+
"The truth value of an empty array is ambiguous"
502+
)
503+
504+
raise ValueError(
505+
"The truth value of an array with more than one element is "
506+
"ambiguous. Use a.any() or a.all()"
507+
)
508+
509+
def __float__(self):
510+
if self.size == 1:
511+
mem_view = dpmem.as_usm_memory(self)
512+
return mem_view.copy_to_host().view(self.dtype).__float__()
513+
514+
raise ValueError(
515+
"only size-1 arrays can be converted to Python scalars"
516+
)
517+
518+
def __complex__(self):
519+
if self.size == 1:
520+
mem_view = dpmem.as_usm_memory(self)
521+
return mem_view.copy_to_host().view(self.dtype).__complex__()
522+
523+
raise ValueError(
524+
"only size-1 arrays can be converted to Python scalars"
525+
)
526+
527+
def __int__(self):
528+
if self.size == 1:
529+
mem_view = dpmem.as_usm_memory(self)
530+
return mem_view.copy_to_host().view(self.dtype).__int__()
531+
532+
raise ValueError(
533+
"only size-1 arrays can be converted to Python scalars"
534+
)
535+
536+
def __index__(self):
537+
if np.issubdtype(self.dtype, np.integer):
538+
return int(self)
539+
540+
raise IndexError("only integer arrays are valid indices")
541+
494542
def to_device(self, target_device):
495543
"""
496544
Transfer array to target device

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,64 @@ def test_properties():
114114
assert isinstance(X.ndim, numbers.Integral)
115115

116116

117+
@pytest.mark.parametrize("func", [bool, float, int, complex])
118+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
119+
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
120+
def test_copy_scalar_with_func(func, shape, dtype):
121+
X = dpt.usm_ndarray(shape, dtype=dtype)
122+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
123+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
124+
assert func(X) == func(Y)
125+
126+
127+
@pytest.mark.parametrize(
128+
"method", ["__bool__", "__float__", "__int__", "__complex__"]
129+
)
130+
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
131+
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
132+
def test_copy_scalar_with_method(method, shape, dtype):
133+
X = dpt.usm_ndarray(shape, dtype=dtype)
134+
Y = np.arange(1, X.size + 1, dtype=dtype).reshape(shape)
135+
X.usm_data.copy_from_host(Y.reshape(-1).view("|u1"))
136+
assert getattr(X, method)() == getattr(Y, method)()
137+
138+
139+
@pytest.mark.parametrize("func", [bool, float, int, complex])
140+
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
141+
def test_copy_scalar_invalid_shape(func, shape):
142+
X = dpt.usm_ndarray(shape)
143+
with pytest.raises(ValueError):
144+
func(X)
145+
146+
147+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
148+
@pytest.mark.parametrize("index_dtype", ["|i8"])
149+
def test_usm_ndarray_as_index(shape, index_dtype):
150+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
151+
Xnp = np.arange(1, X.size + 1, dtype=index_dtype).reshape(shape)
152+
X.usm_data.copy_from_host(Xnp.reshape(-1).view("|u1"))
153+
Y = np.arange(X.size + 1)
154+
assert Y[X] == Y[1]
155+
156+
157+
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
158+
@pytest.mark.parametrize("index_dtype", ["|i8"])
159+
def test_usm_ndarray_as_index_invalid_shape(shape, index_dtype):
160+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
161+
Y = np.arange(X.size + 1)
162+
with pytest.raises(IndexError):
163+
Y[X]
164+
165+
166+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
167+
@pytest.mark.parametrize("index_dtype", ["|f8"])
168+
def test_usm_ndarray_as_index_invalid_dtype(shape, index_dtype):
169+
X = dpt.usm_ndarray(shape, dtype=index_dtype)
170+
Y = np.arange(X.size + 1)
171+
with pytest.raises(IndexError):
172+
Y[X]
173+
174+
117175
@pytest.mark.parametrize(
118176
"ind",
119177
[
@@ -251,6 +309,14 @@ def test_slicing_basic():
251309
Xusm[:, -128]
252310
with pytest.raises(TypeError):
253311
Xusm[{1, 2, 3, 4, 5, 6, 7}]
312+
X = dpt.usm_ndarray(10, "u1")
313+
X.usm_data.copy_from_host(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09")
314+
int(
315+
X[X[2]]
316+
) # check that objects with __index__ method can be used as indices
317+
Xh = dpm.as_usm_memory(X[X[2] : X[5]]).copy_to_host()
318+
Xnp = np.arange(0, 10, dtype="u1")
319+
assert np.array_equal(Xh, Xnp[Xnp[2] : Xnp[5]])
254320

255321

256322
def test_ctor_invalid_shape():
@@ -291,3 +357,19 @@ def test_usm_ndarray_props():
291357
except dpctl.SyclQueueCreationError:
292358
pytest.skip("Sycl device CPU was not detected")
293359
Xusm.to_device("cpu")
360+
361+
362+
def test_datapi_device():
363+
X = dpt.usm_ndarray(1)
364+
dev_t = type(X.device)
365+
with pytest.raises(TypeError):
366+
dev_t()
367+
dev_t.create_device(X.device)
368+
dev_t.create_device(X.sycl_queue)
369+
dev_t.create_device(X.sycl_device)
370+
dev_t.create_device(X.sycl_device.filter_string)
371+
dev_t.create_device(None)
372+
X.device.sycl_context
373+
X.device.sycl_queue
374+
X.device.sycl_device
375+
repr(X.device)

0 commit comments

Comments
 (0)