Skip to content

Commit ed4c5b4

Browse files
More tests to restore coverage of _usmarray.pyx to previous levels
1 parent c1f0cf9 commit ed4c5b4

File tree

3 files changed

+44
-8
lines changed

3 files changed

+44
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def _dispatch_binary_elementwise(ary, name, other):
6060
mod = ary.__array_namespace__()
6161
except AttributeError:
6262
return NotImplemented
63-
mod = ary.__array_namespace__()
6463
if mod is None and "dpnp" in sys.modules:
6564
fn = getattr(sys.modules["dpnp"], name)
6665
if callable(fn):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def test_dtypes_invalid(dtype):
9494
dpt.usm_ndarray((1,), dtype=dtype)
9595

9696

97-
def test_properties():
97+
@pytest.mark.parametrize("dt", ["d", "c16"])
98+
def test_properties(dt):
9899
"""
99100
Test that properties execute
100101
"""
101-
X = dpt.usm_ndarray((3, 4, 5), dtype="c16")
102+
X = dpt.usm_ndarray((3, 4, 5), dtype=dt)
102103
assert isinstance(X.sycl_queue, dpctl.SyclQueue)
103104
assert isinstance(X.sycl_device, dpctl.SyclDevice)
104105
assert isinstance(X.sycl_context, dpctl.SyclContext)
@@ -113,6 +114,7 @@ def test_properties():
113114
assert isinstance(X.size, numbers.Integral)
114115
assert isinstance(X.nbytes, numbers.Integral)
115116
assert isinstance(X.ndim, numbers.Integral)
117+
assert isinstance(X._pointer, numbers.Integral)
116118

117119

118120
@pytest.mark.parametrize("func", [bool, float, int, complex])
@@ -708,6 +710,13 @@ def relaxed_strides_equal(st1, st2, sh):
708710
X.shape = sh_f
709711
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
710712

713+
X = dpt.usm_ndarray(sh_s, dtype="d")
714+
with pytest.raises(TypeError):
715+
X.shape = "abcbe"
716+
X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2]
717+
with pytest.raises(AttributeError):
718+
X.shape = (4,)
719+
711720

712721
def test_len():
713722
X = dpt.usm_ndarray(1, "i4")
@@ -749,3 +758,12 @@ def test_astype():
749758
assert np.allclose(dpt.to_numpy(Y), np.full((5, 5), 7, dtype="f2"))
750759
Y = dpt.astype(X, "i4", order="K", copy=False)
751760
assert Y.usm_data is X.usm_data
761+
762+
763+
def test_ctor_invalid():
764+
m = dpm.MemoryUSMShared(12)
765+
with pytest.raises(ValueError):
766+
dpt.usm_ndarray((4,), dtype="i4", buffer=m)
767+
m = dpm.MemoryUSMShared(64)
768+
with pytest.raises(ValueError):
769+
dpt.usm_ndarray((4,), dtype="u1", buffer=m, strides={"not": "valid"})

dpctl/tests/test_usm_ndarray_operators.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,33 @@
2121

2222

2323
class Dummy:
24-
def __getattr__(self, name):
25-
def first_arg(*args, **kwargs):
26-
return args[0]
24+
@staticmethod
25+
def __abs__(a):
26+
return a
2727

28-
return first_arg
28+
@staticmethod
29+
def __add__(a, b):
30+
if isinstance(a, dpt.usm_ndarray):
31+
return a
32+
else:
33+
return b
2934

35+
@staticmethod
36+
def __sub__(a, b):
37+
if isinstance(a, dpt.usm_ndarray):
38+
return a
39+
else:
40+
return b
3041

31-
@pytest.mark.parametrize("namespace", [None, Dummy])
42+
@staticmethod
43+
def __mul__(a, b):
44+
if isinstance(a, dpt.usm_ndarray):
45+
return a
46+
else:
47+
return b
48+
49+
50+
@pytest.mark.parametrize("namespace", [None, Dummy()])
3251
def test_fp_ops(namespace):
3352
X = dpt.usm_ndarray(1, "d")
3453
X._set_namespace(namespace)

0 commit comments

Comments
 (0)