Skip to content

Commit dc386eb

Browse files
Split tests of elementwise functions into separate files
1 parent 8a8411a commit dc386eb

File tree

10 files changed

+762
-701
lines changed

10 files changed

+762
-701
lines changed

dpctl/tests/elementwise/__init__.py

Whitespace-only changes.

dpctl/tests/elementwise/test_abs.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import itertools
2+
3+
import numpy as np
4+
import pytest
5+
6+
import dpctl.tensor as dpt
7+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
8+
9+
from .utils import _all_dtypes, _usm_types
10+
11+
12+
@pytest.mark.parametrize("dtype", _all_dtypes)
13+
def test_abs_out_type(dtype):
14+
q = get_queue_or_skip()
15+
skip_if_dtype_not_supported(dtype, q)
16+
17+
arg_dt = np.dtype(dtype)
18+
X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q)
19+
if np.issubdtype(arg_dt, np.complexfloating):
20+
type_map = {
21+
np.dtype("c8"): np.dtype("f4"),
22+
np.dtype("c16"): np.dtype("f8"),
23+
}
24+
assert dpt.abs(X).dtype == type_map[arg_dt]
25+
else:
26+
assert dpt.abs(X).dtype == arg_dt
27+
28+
29+
@pytest.mark.parametrize("usm_type", _usm_types)
30+
def test_abs_usm_type(usm_type):
31+
q = get_queue_or_skip()
32+
33+
arg_dt = np.dtype("i4")
34+
input_shape = (10, 10, 10, 10)
35+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
36+
X[..., 0::2] = 1
37+
X[..., 1::2] = 0
38+
39+
Y = dpt.abs(X)
40+
assert Y.usm_type == X.usm_type
41+
assert Y.sycl_queue == X.sycl_queue
42+
assert Y.flags.c_contiguous
43+
44+
expected_Y = dpt.asnumpy(X)
45+
assert np.allclose(dpt.asnumpy(Y), expected_Y)
46+
47+
48+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
49+
def test_abs_order(dtype):
50+
q = get_queue_or_skip()
51+
skip_if_dtype_not_supported(dtype, q)
52+
53+
arg_dt = np.dtype(dtype)
54+
input_shape = (10, 10, 10, 10)
55+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
56+
X[..., 0::2] = 1
57+
X[..., 1::2] = 0
58+
59+
for ord in ["C", "F", "A", "K"]:
60+
for perms in itertools.permutations(range(4)):
61+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
62+
Y = dpt.abs(U, order=ord)
63+
expected_Y = np.ones(Y.shape, dtype=Y.dtype)
64+
expected_Y[..., 1::2] = 0
65+
expected_Y = np.transpose(expected_Y, perms)
66+
assert np.allclose(dpt.asnumpy(Y), expected_Y)
67+
68+
69+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
70+
def test_abs_complex(dtype):
71+
q = get_queue_or_skip()
72+
skip_if_dtype_not_supported(dtype, q)
73+
74+
arg_dt = np.dtype(dtype)
75+
input_shape = (10, 10, 10, 10)
76+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
77+
Xnp = np.random.standard_normal(
78+
size=input_shape
79+
) + 1j * np.random.standard_normal(size=input_shape)
80+
Xnp = Xnp.astype(arg_dt)
81+
X[...] = Xnp
82+
83+
for ord in ["C", "F", "A", "K"]:
84+
for perms in itertools.permutations(range(4)):
85+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
86+
Y = dpt.abs(U, order=ord)
87+
expected_Y = np.abs(np.transpose(Xnp[:, ::-1, ::-1, :], perms))
88+
tol = dpt.finfo(Y.dtype).resolution
89+
np.testing.assert_allclose(
90+
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
91+
)

dpctl/tests/elementwise/test_add.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
import ctypes
2+
3+
import numpy as np
4+
import pytest
5+
6+
import dpctl
7+
import dpctl.tensor as dpt
8+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
9+
10+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
11+
12+
13+
@pytest.mark.parametrize("op1_dtype", _all_dtypes)
14+
@pytest.mark.parametrize("op2_dtype", _all_dtypes)
15+
def test_add_dtype_matrix(op1_dtype, op2_dtype):
16+
q = get_queue_or_skip()
17+
skip_if_dtype_not_supported(op1_dtype, q)
18+
skip_if_dtype_not_supported(op2_dtype, q)
19+
20+
sz = 127
21+
ar1 = dpt.ones(sz, dtype=op1_dtype)
22+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
23+
24+
r = dpt.add(ar1, ar2)
25+
assert isinstance(r, dpt.usm_ndarray)
26+
expected_dtype = np.add(
27+
np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype)
28+
).dtype
29+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
30+
assert r.shape == ar1.shape
31+
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
32+
assert r.sycl_queue == ar1.sycl_queue
33+
34+
ar3 = dpt.ones(sz, dtype=op1_dtype)
35+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
36+
37+
r = dpt.add(ar3[::-1], ar4[::2])
38+
assert isinstance(r, dpt.usm_ndarray)
39+
expected_dtype = np.add(
40+
np.zeros(1, dtype=op1_dtype), np.zeros(1, dtype=op2_dtype)
41+
).dtype
42+
assert _compare_dtypes(r.dtype, expected_dtype, sycl_queue=q)
43+
assert r.shape == ar3.shape
44+
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
45+
46+
47+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
48+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
49+
def test_add_usm_type_matrix(op1_usm_type, op2_usm_type):
50+
get_queue_or_skip()
51+
52+
sz = 128
53+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
54+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
55+
56+
r = dpt.add(ar1, ar2)
57+
assert isinstance(r, dpt.usm_ndarray)
58+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
59+
(op1_usm_type, op2_usm_type)
60+
)
61+
assert r.usm_type == expected_usm_type
62+
63+
64+
def test_add_order():
65+
get_queue_or_skip()
66+
67+
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
68+
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
69+
r1 = dpt.add(ar1, ar2, order="C")
70+
assert r1.flags.c_contiguous
71+
r2 = dpt.add(ar1, ar2, order="F")
72+
assert r2.flags.f_contiguous
73+
r3 = dpt.add(ar1, ar2, order="A")
74+
assert r3.flags.c_contiguous
75+
r4 = dpt.add(ar1, ar2, order="K")
76+
assert r4.flags.c_contiguous
77+
78+
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
79+
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
80+
r1 = dpt.add(ar1, ar2, order="C")
81+
assert r1.flags.c_contiguous
82+
r2 = dpt.add(ar1, ar2, order="F")
83+
assert r2.flags.f_contiguous
84+
r3 = dpt.add(ar1, ar2, order="A")
85+
assert r3.flags.f_contiguous
86+
r4 = dpt.add(ar1, ar2, order="K")
87+
assert r4.flags.f_contiguous
88+
89+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
90+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
91+
r4 = dpt.add(ar1, ar2, order="K")
92+
assert r4.strides == (20, -1)
93+
94+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
95+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
96+
r4 = dpt.add(ar1, ar2, order="K")
97+
assert r4.strides == (-1, 20)
98+
99+
100+
def test_add_broadcasting():
101+
get_queue_or_skip()
102+
103+
m = dpt.ones((100, 5), dtype="i4")
104+
v = dpt.arange(5, dtype="i4")
105+
106+
r = dpt.add(m, v)
107+
108+
assert (dpt.asnumpy(r) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
109+
110+
r2 = dpt.add(v, m)
111+
assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
112+
113+
114+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
115+
def test_add_python_scalar(arr_dt):
116+
q = get_queue_or_skip()
117+
skip_if_dtype_not_supported(arr_dt, q)
118+
119+
X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q)
120+
py_zeros = (
121+
bool(0),
122+
int(0),
123+
float(0),
124+
complex(0),
125+
np.float32(0),
126+
ctypes.c_int(0),
127+
)
128+
for sc in py_zeros:
129+
R = dpt.add(X, sc)
130+
assert isinstance(R, dpt.usm_ndarray)
131+
R = dpt.add(sc, X)
132+
assert isinstance(R, dpt.usm_ndarray)
133+
134+
135+
class MockArray:
136+
def __init__(self, arr):
137+
self.data_ = arr
138+
139+
@property
140+
def __sycl_usm_array_interface__(self):
141+
return self.data_.__sycl_usm_array_interface__
142+
143+
144+
def test_add_mock_array():
145+
get_queue_or_skip()
146+
a = dpt.arange(10)
147+
b = dpt.ones(10)
148+
c = MockArray(b)
149+
r = dpt.add(a, c)
150+
assert isinstance(r, dpt.usm_ndarray)
151+
152+
153+
def test_add_canary_mock_array():
154+
get_queue_or_skip()
155+
a = dpt.arange(10)
156+
157+
class Canary:
158+
def __init__(self):
159+
pass
160+
161+
@property
162+
def __sycl_usm_array_interface__(self):
163+
return None
164+
165+
c = Canary()
166+
with pytest.raises(ValueError):
167+
dpt.add(a, c)

dpctl/tests/elementwise/test_cos.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import itertools
2+
3+
import numpy as np
4+
import pytest
5+
6+
import dpctl.tensor as dpt
7+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
8+
9+
from .utils import _all_dtypes, _map_to_device_dtype
10+
11+
12+
@pytest.mark.parametrize("dtype", _all_dtypes)
13+
def test_cos_out_type(dtype):
14+
q = get_queue_or_skip()
15+
skip_if_dtype_not_supported(dtype, q)
16+
17+
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
18+
expected_dtype = np.cos(np.array(0, dtype=dtype)).dtype
19+
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
20+
assert dpt.cos(X).dtype == expected_dtype
21+
22+
23+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
24+
def test_cos_output(dtype):
25+
q = get_queue_or_skip()
26+
skip_if_dtype_not_supported(dtype, q)
27+
28+
n_seq = 100
29+
n_rep = 137
30+
31+
Xnp = np.linspace(-np.pi / 4, np.pi / 4, num=n_seq, dtype=dtype)
32+
X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q)
33+
34+
Y = dpt.cos(X)
35+
tol = 8 * dpt.finfo(Y.dtype).resolution
36+
37+
np.testing.assert_allclose(
38+
dpt.asnumpy(Y), np.repeat(np.cos(Xnp), n_rep), atol=tol, rtol=tol
39+
)
40+
41+
42+
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
43+
def test_cos_usm_type(usm_type):
44+
q = get_queue_or_skip()
45+
46+
arg_dt = np.dtype("f4")
47+
input_shape = (10, 10, 10, 10)
48+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
49+
X[..., 0::2] = np.pi / 6
50+
X[..., 1::2] = np.pi / 3
51+
52+
Y = dpt.cos(X)
53+
assert Y.usm_type == X.usm_type
54+
assert Y.sycl_queue == X.sycl_queue
55+
assert Y.flags.c_contiguous
56+
57+
expected_Y = np.empty(input_shape, dtype=arg_dt)
58+
expected_Y[..., 0::2] = np.cos(np.float32(np.pi / 6))
59+
expected_Y[..., 1::2] = np.cos(np.float32(np.pi / 3))
60+
tol = 8 * dpt.finfo(Y.dtype).resolution
61+
62+
np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
63+
64+
65+
@pytest.mark.parametrize("dtype", _all_dtypes)
66+
def test_cos_order(dtype):
67+
q = get_queue_or_skip()
68+
skip_if_dtype_not_supported(dtype, q)
69+
70+
arg_dt = np.dtype(dtype)
71+
input_shape = (10, 10, 10, 10)
72+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
73+
X[..., 0::2] = np.pi / 6
74+
X[..., 1::2] = np.pi / 3
75+
76+
for ord in ["C", "F", "A", "K"]:
77+
for perms in itertools.permutations(range(4)):
78+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
79+
Y = dpt.cos(U, order=ord)
80+
expected_Y = np.cos(dpt.asnumpy(U))
81+
tol = 8 * max(
82+
dpt.finfo(Y.dtype).resolution,
83+
np.finfo(expected_Y.dtype).resolution,
84+
)
85+
np.testing.assert_allclose(
86+
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
87+
)

0 commit comments

Comments
 (0)