Skip to content

Commit 2be09e5

Browse files
Adding tests for dpt.divide
1 parent 89f8fe9 commit 2be09e5

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import ctypes
2+
3+
import numpy as np
4+
import pytest
5+
6+
import dpctl
7+
import dpctl.tensor as dpt
8+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
9+
10+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
11+
12+
13+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
14+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
15+
def test_divide_dtype_matrix(op1_dtype, op2_dtype):
16+
q = get_queue_or_skip()
17+
skip_if_dtype_not_supported(op1_dtype, q)
18+
skip_if_dtype_not_supported(op2_dtype, q)
19+
20+
sz = 127
21+
ar1 = dpt.ones(sz, dtype=op1_dtype)
22+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
23+
24+
r = dpt.divide(ar1, ar2)
25+
assert isinstance(r, dpt.usm_ndarray)
26+
expected = np.divide(
27+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
28+
)
29+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
30+
assert r.shape == ar1.shape
31+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
32+
assert r.sycl_queue == ar1.sycl_queue
33+
34+
ar3 = dpt.ones(sz, dtype=op1_dtype)
35+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
36+
37+
r = dpt.divide(ar3[::-1], ar4[::2])
38+
assert isinstance(r, dpt.usm_ndarray)
39+
expected = np.divide(
40+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
41+
)
42+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
43+
assert r.shape == ar3.shape
44+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
45+
46+
47+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
48+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
49+
def test_divide_usm_type_matrix(op1_usm_type, op2_usm_type):
50+
get_queue_or_skip()
51+
52+
sz = 128
53+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
54+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
55+
56+
r = dpt.divide(ar1, ar2)
57+
assert isinstance(r, dpt.usm_ndarray)
58+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
59+
(op1_usm_type, op2_usm_type)
60+
)
61+
assert r.usm_type == expected_usm_type
62+
63+
64+
def test_divide_order():
65+
get_queue_or_skip()
66+
67+
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
68+
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
69+
r1 = dpt.divide(ar1, ar2, order="C")
70+
assert r1.flags.c_contiguous
71+
r2 = dpt.divide(ar1, ar2, order="F")
72+
assert r2.flags.f_contiguous
73+
r3 = dpt.divide(ar1, ar2, order="A")
74+
assert r3.flags.c_contiguous
75+
r4 = dpt.divide(ar1, ar2, order="K")
76+
assert r4.flags.c_contiguous
77+
78+
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
79+
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
80+
r1 = dpt.divide(ar1, ar2, order="C")
81+
assert r1.flags.c_contiguous
82+
r2 = dpt.divide(ar1, ar2, order="F")
83+
assert r2.flags.f_contiguous
84+
r3 = dpt.divide(ar1, ar2, order="A")
85+
assert r3.flags.f_contiguous
86+
r4 = dpt.divide(ar1, ar2, order="K")
87+
assert r4.flags.f_contiguous
88+
89+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
90+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
91+
r4 = dpt.divide(ar1, ar2, order="K")
92+
assert r4.strides == (20, -1)
93+
94+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
95+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
96+
r4 = dpt.divide(ar1, ar2, order="K")
97+
assert r4.strides == (-1, 20)
98+
99+
100+
def test_divide_broadcasting():
101+
get_queue_or_skip()
102+
103+
m = dpt.ones((100, 5), dtype="i4")
104+
v = dpt.arange(1, 6, dtype="i4")
105+
106+
r = dpt.divide(m, v)
107+
108+
expected = np.divide(
109+
np.ones((100, 5), dtype="i4"), np.arange(1, 6, dtype="i4")
110+
)
111+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
112+
113+
r2 = dpt.divide(v, m)
114+
expected2 = np.divide(
115+
np.arange(1, 6, dtype="i4"), np.ones((100, 5), dtype="i4")
116+
)
117+
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
118+
119+
120+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
121+
def test_divide_python_scalar(arr_dt):
122+
q = get_queue_or_skip()
123+
skip_if_dtype_not_supported(arr_dt, q)
124+
125+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
126+
py_ones = (
127+
bool(1),
128+
int(1),
129+
float(1),
130+
complex(1),
131+
np.float32(1),
132+
ctypes.c_int(1),
133+
)
134+
for sc in py_ones:
135+
R = dpt.divide(X, sc)
136+
assert isinstance(R, dpt.usm_ndarray)
137+
R = dpt.divide(sc, X)
138+
assert isinstance(R, dpt.usm_ndarray)
139+
140+
141+
class MockArray:
142+
def __init__(self, arr):
143+
self.data_ = arr
144+
145+
@property
146+
def __sycl_usm_array_interface__(self):
147+
return self.data_.__sycl_usm_array_interface__
148+
149+
150+
def test_divide_mock_array():
151+
get_queue_or_skip()
152+
a = dpt.arange(10)
153+
b = dpt.ones(10)
154+
c = MockArray(b)
155+
r = dpt.divide(a, c)
156+
assert isinstance(r, dpt.usm_ndarray)
157+
158+
159+
def test_divide_canary_mock_array():
160+
get_queue_or_skip()
161+
a = dpt.arange(10)
162+
163+
class Canary:
164+
def __init__(self):
165+
pass
166+
167+
@property
168+
def __sycl_usm_array_interface__(self):
169+
return None
170+
171+
c = Canary()
172+
with pytest.raises(ValueError):
173+
dpt.divide(a, c)

0 commit comments

Comments
 (0)