Skip to content

Commit 5ae12c4

Browse files
committed
Adds tests for cbrt, copysign, and exp2
1 parent 7a19a10 commit 5ae12c4

File tree

3 files changed

+461
-0
lines changed

3 files changed

+461
-0
lines changed

dpctl/tests/elementwise/test_cbrt.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
import warnings
19+
20+
import numpy as np
21+
import pytest
22+
from numpy.testing import assert_allclose, assert_equal
23+
24+
import dpctl.tensor as dpt
25+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
26+
27+
from .utils import (
28+
_all_dtypes,
29+
_complex_fp_dtypes,
30+
_map_to_device_dtype,
31+
_real_fp_dtypes,
32+
_usm_types,
33+
)
34+
35+
36+
@pytest.mark.parametrize("dtype", _all_dtypes)
37+
def test_cbrt_out_type(dtype):
38+
q = get_queue_or_skip()
39+
skip_if_dtype_not_supported(dtype, q)
40+
41+
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
42+
expected_dtype = np.cbrt(np.array(0, dtype=dtype)).dtype
43+
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
44+
assert dpt.cbrt(X).dtype == expected_dtype
45+
46+
47+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
48+
def test_cbrt_output_contig(dtype):
49+
q = get_queue_or_skip()
50+
skip_if_dtype_not_supported(dtype, q)
51+
52+
n_seq = 1027
53+
54+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)
55+
Xnp = dpt.asnumpy(X)
56+
57+
Y = dpt.cbrt(X)
58+
tol = 8 * dpt.finfo(Y.dtype).resolution
59+
60+
assert_allclose(dpt.asnumpy(Y), np.cbrt(Xnp), atol=tol, rtol=tol)
61+
62+
63+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8", "c8", "c16"])
64+
def test_cbrt_output_strided(dtype):
65+
q = get_queue_or_skip()
66+
skip_if_dtype_not_supported(dtype, q)
67+
68+
n_seq = 2054
69+
70+
X = dpt.linspace(0, 13, num=n_seq, dtype=dtype, sycl_queue=q)[::-2]
71+
Xnp = dpt.asnumpy(X)
72+
73+
Y = dpt.cbrt(X)
74+
tol = 8 * dpt.finfo(Y.dtype).resolution
75+
76+
assert_allclose(dpt.asnumpy(Y), np.cbrt(Xnp), atol=tol, rtol=tol)
77+
78+
79+
@pytest.mark.parametrize("usm_type", _usm_types)
80+
def test_cbrt_usm_type(usm_type):
81+
q = get_queue_or_skip()
82+
83+
arg_dt = np.dtype("f4")
84+
input_shape = (10, 10, 10, 10)
85+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
86+
X[..., 0::2] = 16.0
87+
X[..., 1::2] = 23.0
88+
89+
Y = dpt.cbrt(X)
90+
assert Y.usm_type == X.usm_type
91+
assert Y.sycl_queue == X.sycl_queue
92+
assert Y.flags.c_contiguous
93+
94+
expected_Y = np.empty(input_shape, dtype=arg_dt)
95+
expected_Y[..., 0::2] = np.cbrt(np.float32(16.0))
96+
expected_Y[..., 1::2] = np.cbrt(np.float32(23.0))
97+
tol = 8 * dpt.finfo(Y.dtype).resolution
98+
99+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
100+
101+
102+
@pytest.mark.parametrize("dtype", _all_dtypes)
103+
def test_cbrt_order(dtype):
104+
q = get_queue_or_skip()
105+
skip_if_dtype_not_supported(dtype, q)
106+
107+
arg_dt = np.dtype(dtype)
108+
input_shape = (10, 10, 10, 10)
109+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
110+
X[..., 0::2] = 16.0
111+
X[..., 1::2] = 23.0
112+
113+
for ord in ["C", "F", "A", "K"]:
114+
for perms in itertools.permutations(range(4)):
115+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
116+
Y = dpt.cbrt(U, order=ord)
117+
expected_Y = np.cbrt(dpt.asnumpy(U))
118+
tol = 8 * max(
119+
dpt.finfo(Y.dtype).resolution,
120+
np.finfo(expected_Y.dtype).resolution,
121+
)
122+
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
123+
124+
125+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
126+
def test_cbrt_special_cases():
127+
q = get_queue_or_skip()
128+
129+
X = dpt.asarray(
130+
[dpt.nan, -1.0, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
131+
)
132+
Xnp = dpt.asnumpy(X)
133+
134+
assert_equal(dpt.asnumpy(dpt.cbrt(X)), np.cbrt(Xnp))
135+
136+
137+
@pytest.mark.parametrize("dtype", _real_fp_dtypes)
138+
def test_cbrt_real_fp_special_values(dtype):
139+
q = get_queue_or_skip()
140+
skip_if_dtype_not_supported(dtype, q)
141+
142+
nans_ = [dpt.nan, -dpt.nan]
143+
infs_ = [dpt.inf, -dpt.inf]
144+
finites_ = [-1.0, -0.0, 0.0, 1.0]
145+
inps_ = nans_ + infs_ + finites_
146+
147+
x = dpt.asarray(inps_, dtype=dtype)
148+
r = dpt.cbrt(x)
149+
150+
with warnings.catch_warnings():
151+
warnings.simplefilter("ignore")
152+
expected_np = np.cbrt(np.asarray(inps_, dtype=dtype))
153+
154+
expected = dpt.asarray(expected_np, dtype=dtype)
155+
tol = dpt.finfo(r.dtype).resolution
156+
157+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
158+
159+
160+
@pytest.mark.broken_complex
161+
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
162+
def test_cbrt_complex_fp_special_values(dtype):
163+
q = get_queue_or_skip()
164+
skip_if_dtype_not_supported(dtype, q)
165+
166+
nans_ = [dpt.nan, -dpt.nan]
167+
infs_ = [dpt.inf, -dpt.inf]
168+
finites_ = [-1.0, -0.0, 0.0, 1.0]
169+
inps_ = nans_ + infs_ + finites_
170+
c_ = [complex(*v) for v in itertools.product(inps_, repeat=2)]
171+
172+
z = dpt.asarray(c_, dtype=dtype)
173+
r = dpt.cbrt(z)
174+
175+
with warnings.catch_warnings():
176+
warnings.simplefilter("ignore")
177+
expected_np = np.cbrt(np.asarray(c_, dtype=dtype))
178+
179+
expected = dpt.asarray(expected_np, dtype=dtype)
180+
tol = dpt.finfo(r.dtype).resolution
181+
182+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
24+
25+
from .utils import _compare_dtypes, _real_fp_dtypes
26+
27+
28+
@pytest.mark.parametrize("op1_dtype", _real_fp_dtypes)
29+
@pytest.mark.parametrize("op2_dtype", _real_fp_dtypes)
30+
def test_copysign_dtype_matrix(op1_dtype, op2_dtype):
31+
q = get_queue_or_skip()
32+
skip_if_dtype_not_supported(op1_dtype, q)
33+
skip_if_dtype_not_supported(op2_dtype, q)
34+
35+
sz = 127
36+
ar1 = dpt.ones(sz, dtype=op1_dtype)
37+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
38+
39+
r = dpt.copysign(ar1, ar2)
40+
assert isinstance(r, dpt.usm_ndarray)
41+
expected = np.copysign(
42+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
43+
)
44+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
45+
assert r.shape == ar1.shape
46+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
47+
assert r.sycl_queue == ar1.sycl_queue
48+
49+
ar3 = dpt.ones(sz, dtype=op1_dtype)
50+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
51+
52+
r = dpt.copysign(ar3[::-1], ar4[::2])
53+
assert isinstance(r, dpt.usm_ndarray)
54+
expected = np.copysign(
55+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
56+
)
57+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
58+
assert r.shape == ar3.shape
59+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
60+
61+
62+
@pytest.mark.parametrize("arr_dt", _real_fp_dtypes)
63+
def test_copysign_python_scalar(arr_dt):
64+
q = get_queue_or_skip()
65+
skip_if_dtype_not_supported(arr_dt, q)
66+
67+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
68+
py_ones = (
69+
bool(1),
70+
int(1),
71+
float(1),
72+
np.float32(1),
73+
ctypes.c_int(1),
74+
)
75+
for sc in py_ones:
76+
R = dpt.copysign(X, sc)
77+
assert isinstance(R, dpt.usm_ndarray)
78+
R = dpt.copysign(sc, X)
79+
assert isinstance(R, dpt.usm_ndarray)
80+
81+
82+
@pytest.mark.parametrize("dt", _real_fp_dtypes)
83+
def test_copysign(dt):
84+
q = get_queue_or_skip()
85+
skip_if_dtype_not_supported(dt, q)
86+
87+
x = dpt.arange(100, dtype=dt, sycl_queue=q)
88+
x[1::2] *= -1
89+
y = dpt.ones(100, dtype=dt, sycl_queue=q)
90+
y[::2] *= -1
91+
res = dpt.copysign(x, y)
92+
expected = dpt.negative(x)
93+
tol = dpt.finfo(dt).resolution
94+
assert dpt.allclose(res, expected, atol=tol, rtol=tol)
95+
96+
97+
def test_copysign_special_values():
98+
get_queue_or_skip()
99+
100+
x1 = dpt.asarray([1.0, 0.0, dpt.nan, dpt.nan], dtype="f4")
101+
y1 = dpt.asarray([-1.0, -0.0, -dpt.nan, -1], dtype="f4")
102+
res = dpt.copysign(x1, y1)
103+
assert dpt.all(dpt.signbit(res))
104+
x2 = dpt.asarray([-1.0, -0.0, -dpt.nan, -dpt.nan], dtype="f4")
105+
res = dpt.copysign(x2, y1)
106+
assert dpt.all(dpt.signbit(res))
107+
y2 = dpt.asarray([0.0, 1.0, dpt.nan, 1.0], dtype="f4")
108+
res = dpt.copysign(x2, y2)
109+
assert not dpt.any(dpt.signbit(res))
110+
res = dpt.copysign(x1, y2)
111+
assert not dpt.any(dpt.signbit(res))

0 commit comments

Comments
 (0)