Skip to content

Commit ab701e1

Browse files
densmirnoleksandr-pavlyk
authored andcommitted
Add methods __bool__, __float__ and __int__ to usm_ndarray
1 parent b97efdf commit ab701e1

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,36 @@ cdef class usm_ndarray:
491491
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
492492
return res
493493

494+
def __bool__(self):
495+
if self.size == 1:
496+
return self.usm_data.copy_to_host().view(self.dtype).__bool__()
497+
498+
if self.size == 0:
499+
raise ValueError(
500+
"The truth value of an empty array is ambiguous"
501+
)
502+
503+
raise ValueError(
504+
"The truth value of an array with more than one element is "
505+
"ambiguous. Use a.any() or a.all()"
506+
)
507+
508+
def __float__(self):
509+
if self.size == 1:
510+
return self.usm_data.copy_to_host().view(self.dtype).__float__()
511+
512+
raise ValueError(
513+
"only size-1 arrays can be converted to Python scalars"
514+
)
515+
516+
def __int__(self):
517+
if self.size == 1:
518+
return self.usm_data.copy_to_host().view(self.dtype).__int__()
519+
520+
raise ValueError(
521+
"only size-1 arrays can be converted to Python scalars"
522+
)
523+
494524
def to_device(self, target_device):
495525
"""
496526
Transfer array to target device

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,32 @@ def test_properties():
114114
assert isinstance(X.ndim, numbers.Integral)
115115

116116

117+
@pytest.mark.parametrize("func", [bool, float, int])
118+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
119+
def test_copy_scalar_with_func(func, shape):
120+
X = dpt.usm_ndarray(shape)
121+
Y = np.arange(1, X.size + 1, dtype=X.dtype)
122+
X.usm_data.copy_from_host(Y.view("|u1"))
123+
assert func(X) == func(Y)
124+
125+
126+
@pytest.mark.parametrize("method", ["__bool__", "__float__", "__int__"])
127+
@pytest.mark.parametrize("shape", [(1,), (1, 1), (1, 1, 1)])
128+
def test_copy_scalar_with_method(method, shape):
129+
X = dpt.usm_ndarray(shape)
130+
Y = np.arange(1, X.size + 1, dtype=X.dtype)
131+
X.usm_data.copy_from_host(Y.view("|u1"))
132+
assert getattr(X, method)() == getattr(Y, method)()
133+
134+
135+
@pytest.mark.parametrize("func", [bool, float, int])
136+
@pytest.mark.parametrize("shape", [(2,), (1, 2), (3, 4, 5), (0,)])
137+
def test_copy_scalar_invalid_shape(func, shape):
138+
X = dpt.usm_ndarray(shape)
139+
with pytest.raises(ValueError):
140+
func(X)
141+
142+
117143
@pytest.mark.parametrize(
118144
"ind",
119145
[

0 commit comments

Comments
 (0)