Skip to content

Commit 5d37d74

Browse files
committed
Added tests for dpctl.tensor.sqrt
1 parent 4952132 commit 5d37d74

File tree

1 file changed

+133
-0
lines changed

1 file changed

+133
-0
lines changed

dpctl/tests/test_tensor_sqrt.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import itertools
2+
3+
import numpy as np
4+
import pytest
5+
from numpy.testing import assert_equal
6+
7+
import dpctl.tensor as dpt
8+
import dpctl.tensor._type_utils as tu
9+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
10+
11+
_all_dtypes = [
12+
"b1",
13+
"i1",
14+
"u1",
15+
"i2",
16+
"u2",
17+
"i4",
18+
"u4",
19+
"i8",
20+
"u8",
21+
"f2",
22+
"f4",
23+
"f8",
24+
"c8",
25+
"c16",
26+
]
27+
_usm_types = ["device", "shared", "host"]
28+
29+
30+
def _map_to_device_dtype(dt, dev):
31+
return tu._to_device_supported_dtype(dt, dev)
32+
33+
34+
@pytest.mark.parametrize("dtype", _all_dtypes)
35+
def test_sqrt_out_type(dtype):
36+
q = get_queue_or_skip()
37+
skip_if_dtype_not_supported(dtype, q)
38+
39+
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
40+
expected_dtype = np.sqrt(np.array(0, dtype=dtype)).dtype
41+
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
42+
assert dpt.sqrt(X).dtype == expected_dtype
43+
44+
45+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
46+
def test_sqrt_output_contig(dtype):
47+
q = get_queue_or_skip()
48+
skip_if_dtype_not_supported(dtype, q)
49+
50+
n_seq = 1027
51+
52+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)
53+
Xnp = dpt.asnumpy(X)
54+
55+
Y = dpt.sqrt(X)
56+
tol = 8 * dpt.finfo(Y.dtype).resolution
57+
58+
np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
59+
60+
61+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
62+
def test_sqrt_output_strided(dtype):
63+
q = get_queue_or_skip()
64+
skip_if_dtype_not_supported(dtype, q)
65+
66+
n_seq = 2054
67+
68+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
69+
Xnp = dpt.asnumpy(X)
70+
71+
Y = dpt.sqrt(X)
72+
tol = 8 * dpt.finfo(Y.dtype).resolution
73+
74+
np.testing.assert_allclose(dpt.asnumpy(Y), np.sqrt(Xnp), atol=tol, rtol=tol)
75+
76+
77+
@pytest.mark.parametrize("usm_type", _usm_types)
78+
def test_sqrt_usm_type(usm_type):
79+
q = get_queue_or_skip()
80+
81+
arg_dt = np.dtype("f4")
82+
input_shape = (10, 10, 10, 10)
83+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
84+
X[..., 0::2] = 16.0
85+
X[..., 1::2] = 23.0
86+
87+
Y = dpt.sqrt(X)
88+
assert Y.usm_type == X.usm_type
89+
assert Y.sycl_queue == X.sycl_queue
90+
assert Y.flags.c_contiguous
91+
92+
expected_Y = np.empty(input_shape, dtype=arg_dt)
93+
expected_Y[..., 0::2] = np.sqrt(np.float32(16.0))
94+
expected_Y[..., 1::2] = np.sqrt(np.float32(23.0))
95+
tol = 8 * dpt.finfo(Y.dtype).resolution
96+
97+
np.testing.assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
98+
99+
100+
@pytest.mark.parametrize("dtype", _all_dtypes)
101+
def test_sqrt_order(dtype):
102+
q = get_queue_or_skip()
103+
skip_if_dtype_not_supported(dtype, q)
104+
105+
arg_dt = np.dtype(dtype)
106+
input_shape = (10, 10, 10, 10)
107+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
108+
X[..., 0::2] = 16.0
109+
X[..., 1::2] = 23.0
110+
111+
for ord in ["C", "F", "A", "K"]:
112+
for perms in itertools.permutations(range(4)):
113+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
114+
Y = dpt.sqrt(U, order=ord)
115+
expected_Y = np.sqrt(dpt.asnumpy(U))
116+
tol = 8 * max(
117+
dpt.finfo(Y.dtype).resolution,
118+
np.finfo(expected_Y.dtype).resolution,
119+
)
120+
np.testing.assert_allclose(
121+
dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol
122+
)
123+
124+
125+
def test_sqrt_special_cases():
126+
q = get_queue_or_skip()
127+
128+
X = dpt.asarray(
129+
[dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
130+
)
131+
Xnp = dpt.asnumpy(X)
132+
133+
assert_equal(dpt.asnumpy(dpt.sqrt(X)), np.sqrt(Xnp))

0 commit comments

Comments
 (0)