Skip to content

Commit a3e209f

Browse files
More tests to restore coverage of _usmarray.pyx to previous levels
1 parent c1f0cf9 commit a3e209f

File tree

3 files changed

+52
-8
lines changed

3 files changed

+52
-8
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def _dispatch_binary_elementwise(ary, name, other):
6060
mod = ary.__array_namespace__()
6161
except AttributeError:
6262
return NotImplemented
63-
mod = ary.__array_namespace__()
6463
if mod is None and "dpnp" in sys.modules:
6564
fn = getattr(sys.modules["dpnp"], name)
6665
if callable(fn):

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,12 @@ def test_dtypes_invalid(dtype):
9494
dpt.usm_ndarray((1,), dtype=dtype)
9595

9696

97-
def test_properties():
97+
@pytest.mark.parametrize("dt", ["d", "c16"])
98+
def test_properties(dt):
9899
"""
99100
Test that properties execute
100101
"""
101-
X = dpt.usm_ndarray((3, 4, 5), dtype="c16")
102+
X = dpt.usm_ndarray((3, 4, 5), dtype=dt)
102103
assert isinstance(X.sycl_queue, dpctl.SyclQueue)
103104
assert isinstance(X.sycl_device, dpctl.SyclDevice)
104105
assert isinstance(X.sycl_context, dpctl.SyclContext)
@@ -113,6 +114,7 @@ def test_properties():
113114
assert isinstance(X.size, numbers.Integral)
114115
assert isinstance(X.nbytes, numbers.Integral)
115116
assert isinstance(X.ndim, numbers.Integral)
117+
assert isinstance(X._pointer, numbers.Integral)
116118

117119

118120
@pytest.mark.parametrize("func", [bool, float, int, complex])
@@ -145,6 +147,14 @@ def test_copy_scalar_invalid_shape(func, shape):
145147
func(X)
146148

147149

150+
def test_index_noninteger():
151+
import operator
152+
153+
X = dpt.usm_ndarray(1, "d")
154+
with pytest.raises(IndexError):
155+
operator.index(X)
156+
157+
148158
@pytest.mark.parametrize(
149159
"ind",
150160
[
@@ -708,6 +718,13 @@ def relaxed_strides_equal(st1, st2, sh):
708718
X.shape = sh_f
709719
assert relaxed_strides_equal(X.strides, cc_strides(sh_f), sh_f)
710720

721+
X = dpt.usm_ndarray(sh_s, dtype="d")
722+
with pytest.raises(TypeError):
723+
X.shape = "abcbe"
724+
X = dpt.usm_ndarray((4, 4), dtype="d")[::2, ::2]
725+
with pytest.raises(AttributeError):
726+
X.shape = (4,)
727+
711728

712729
def test_len():
713730
X = dpt.usm_ndarray(1, "i4")
@@ -749,3 +766,12 @@ def test_astype():
749766
assert np.allclose(dpt.to_numpy(Y), np.full((5, 5), 7, dtype="f2"))
750767
Y = dpt.astype(X, "i4", order="K", copy=False)
751768
assert Y.usm_data is X.usm_data
769+
770+
771+
def test_ctor_invalid():
772+
m = dpm.MemoryUSMShared(12)
773+
with pytest.raises(ValueError):
774+
dpt.usm_ndarray((4,), dtype="i4", buffer=m)
775+
m = dpm.MemoryUSMShared(64)
776+
with pytest.raises(ValueError):
777+
dpt.usm_ndarray((4,), dtype="u1", buffer=m, strides={"not": "valid"})

dpctl/tests/test_usm_ndarray_operators.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,33 @@
2121

2222

2323
class Dummy:
24-
def __getattr__(self, name):
25-
def first_arg(*args, **kwargs):
26-
return args[0]
24+
@staticmethod
25+
def __abs__(a):
26+
return a
2727

28-
return first_arg
28+
@staticmethod
29+
def __add__(a, b):
30+
if isinstance(a, dpt.usm_ndarray):
31+
return a
32+
else:
33+
return b
2934

35+
@staticmethod
36+
def __sub__(a, b):
37+
if isinstance(a, dpt.usm_ndarray):
38+
return a
39+
else:
40+
return b
3041

31-
@pytest.mark.parametrize("namespace", [None, Dummy])
42+
@staticmethod
43+
def __mul__(a, b):
44+
if isinstance(a, dpt.usm_ndarray):
45+
return a
46+
else:
47+
return b
48+
49+
50+
@pytest.mark.parametrize("namespace", [None, Dummy()])
3251
def test_fp_ops(namespace):
3352
X = dpt.usm_ndarray(1, "d")
3453
X._set_namespace(namespace)

0 commit comments

Comments
 (0)