Skip to content

Commit 70b56e3

Browse files
committed
Tests for negative, positive, pow, and square
1 parent 883fd26 commit 70b56e3

File tree

4 files changed

+405
-0
lines changed

4 files changed

+405
-0
lines changed
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
24+
25+
from .utils import _all_dtypes, _usm_types
26+
27+
28+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
29+
def test_negative_out_type(dtype):
30+
q = get_queue_or_skip()
31+
skip_if_dtype_not_supported(dtype, q)
32+
33+
arg_dt = np.dtype(dtype)
34+
X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q)
35+
assert dpt.negative(X).dtype == arg_dt
36+
37+
r = dpt.empty_like(X, dtype=arg_dt)
38+
dpt.negative(X, out=r)
39+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.negative(X)))
40+
41+
42+
@pytest.mark.parametrize("usm_type", _usm_types)
43+
def test_negative_usm_type(usm_type):
44+
q = get_queue_or_skip()
45+
46+
arg_dt = np.dtype("i4")
47+
input_shape = (10, 10, 10, 10)
48+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
49+
X[..., 0::2] = 1
50+
X[..., 1::2] = 0
51+
52+
Y = dpt.negative(X)
53+
assert Y.usm_type == X.usm_type
54+
assert Y.sycl_queue == X.sycl_queue
55+
assert Y.flags.c_contiguous
56+
57+
expected_Y = np.negative(dpt.asnumpy(X))
58+
assert np.allclose(dpt.asnumpy(Y), expected_Y)
59+
60+
61+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
62+
def test_negative_order(dtype):
63+
q = get_queue_or_skip()
64+
skip_if_dtype_not_supported(dtype, q)
65+
66+
arg_dt = np.dtype(dtype)
67+
input_shape = (10, 10, 10, 10)
68+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
69+
X[..., 0::2] = 1
70+
X[..., 1::2] = 0
71+
72+
for ord in ["C", "F", "A", "K"]:
73+
for perms in itertools.permutations(range(4)):
74+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
75+
Y = dpt.negative(U, order=ord)
76+
expected_Y = np.negative(np.ones(Y.shape, dtype=Y.dtype))
77+
expected_Y[..., 1::2] = 0
78+
expected_Y = np.transpose(expected_Y, perms)
79+
assert np.allclose(dpt.asnumpy(Y), expected_Y)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
24+
25+
from .utils import _all_dtypes, _usm_types
26+
27+
28+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
29+
def test_positive_out_type(dtype):
30+
q = get_queue_or_skip()
31+
skip_if_dtype_not_supported(dtype, q)
32+
33+
arg_dt = np.dtype(dtype)
34+
X = dpt.asarray(0, dtype=arg_dt, sycl_queue=q)
35+
assert dpt.positive(X).dtype == arg_dt
36+
37+
r = dpt.empty_like(X, dtype=arg_dt)
38+
dpt.positive(X, out=r)
39+
assert np.allclose(dpt.asnumpy(r), dpt.asnumpy(dpt.positive(X)))
40+
41+
42+
@pytest.mark.parametrize("usm_type", _usm_types)
43+
def test_positive_usm_type(usm_type):
44+
q = get_queue_or_skip()
45+
46+
arg_dt = np.dtype("i4")
47+
input_shape = (10, 10, 10, 10)
48+
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
49+
X[..., 0::2] = 1
50+
X[..., 1::2] = 0
51+
52+
Y = dpt.positive(X)
53+
assert Y.usm_type == X.usm_type
54+
assert Y.sycl_queue == X.sycl_queue
55+
assert Y.flags.c_contiguous
56+
57+
expected_Y = dpt.asnumpy(X)
58+
assert np.allclose(dpt.asnumpy(Y), expected_Y)
59+
60+
61+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
62+
def test_positive_order(dtype):
63+
q = get_queue_or_skip()
64+
skip_if_dtype_not_supported(dtype, q)
65+
66+
arg_dt = np.dtype(dtype)
67+
input_shape = (10, 10, 10, 10)
68+
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
69+
X[..., 0::2] = 1
70+
X[..., 1::2] = 0
71+
72+
for ord in ["C", "F", "A", "K"]:
73+
for perms in itertools.permutations(range(4)):
74+
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
75+
Y = dpt.positive(U, order=ord)
76+
expected_Y = np.ones(Y.shape, dtype=Y.dtype)
77+
expected_Y[..., 1::2] = 0
78+
expected_Y = np.transpose(expected_Y, perms)
79+
assert np.allclose(dpt.asnumpy(Y), expected_Y)

dpctl/tests/elementwise/test_pow.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import ctypes
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _compare_dtypes, _usm_types
27+
28+
29+
@pytest.mark.parametrize("op1_dtype", _all_dtypes[1:])
30+
@pytest.mark.parametrize("op2_dtype", _all_dtypes[1:])
31+
def test_power_dtype_matrix(op1_dtype, op2_dtype):
32+
q = get_queue_or_skip()
33+
skip_if_dtype_not_supported(op1_dtype, q)
34+
skip_if_dtype_not_supported(op2_dtype, q)
35+
36+
sz = 127
37+
ar1 = dpt.ones(sz, dtype=op1_dtype)
38+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
39+
40+
r = dpt.pow(ar1, ar2)
41+
assert isinstance(r, dpt.usm_ndarray)
42+
expected = np.power(
43+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
44+
)
45+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
46+
assert r.shape == ar1.shape
47+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
48+
assert r.sycl_queue == ar1.sycl_queue
49+
50+
ar3 = dpt.ones(sz, dtype=op1_dtype)
51+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
52+
53+
r = dpt.pow(ar3[::-1], ar4[::2])
54+
assert isinstance(r, dpt.usm_ndarray)
55+
expected = np.power(
56+
np.ones(1, dtype=op1_dtype), np.ones(1, dtype=op2_dtype)
57+
)
58+
assert _compare_dtypes(r.dtype, expected.dtype, sycl_queue=q)
59+
assert r.shape == ar3.shape
60+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
61+
62+
63+
@pytest.mark.parametrize("op1_usm_type", _usm_types)
64+
@pytest.mark.parametrize("op2_usm_type", _usm_types)
65+
def test_power_usm_type_matrix(op1_usm_type, op2_usm_type):
66+
get_queue_or_skip()
67+
68+
sz = 128
69+
ar1 = dpt.ones(sz, dtype="i4", usm_type=op1_usm_type)
70+
ar2 = dpt.ones_like(ar1, dtype="i4", usm_type=op2_usm_type)
71+
72+
r = dpt.pow(ar1, ar2)
73+
assert isinstance(r, dpt.usm_ndarray)
74+
expected_usm_type = dpctl.utils.get_coerced_usm_type(
75+
(op1_usm_type, op2_usm_type)
76+
)
77+
assert r.usm_type == expected_usm_type
78+
79+
80+
def test_pow_order():
81+
get_queue_or_skip()
82+
83+
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
84+
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
85+
r1 = dpt.pow(ar1, ar2, order="C")
86+
assert r1.flags.c_contiguous
87+
r2 = dpt.pow(ar1, ar2, order="F")
88+
assert r2.flags.f_contiguous
89+
r3 = dpt.pow(ar1, ar2, order="A")
90+
assert r3.flags.c_contiguous
91+
r4 = dpt.pow(ar1, ar2, order="K")
92+
assert r4.flags.c_contiguous
93+
94+
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
95+
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
96+
r1 = dpt.pow(ar1, ar2, order="C")
97+
assert r1.flags.c_contiguous
98+
r2 = dpt.pow(ar1, ar2, order="F")
99+
assert r2.flags.f_contiguous
100+
r3 = dpt.pow(ar1, ar2, order="A")
101+
assert r3.flags.f_contiguous
102+
r4 = dpt.pow(ar1, ar2, order="K")
103+
assert r4.flags.f_contiguous
104+
105+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
106+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
107+
r4 = dpt.pow(ar1, ar2, order="K")
108+
assert r4.strides == (20, -1)
109+
110+
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
111+
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
112+
r4 = dpt.pow(ar1, ar2, order="K")
113+
assert r4.strides == (-1, 20)
114+
115+
116+
def test_pow_broadcasting():
117+
get_queue_or_skip()
118+
119+
v = dpt.arange(1, 6, dtype="i4")
120+
m = dpt.full((100, 5), 2, dtype="i4")
121+
122+
r = dpt.pow(m, v)
123+
124+
expected = np.power(
125+
np.full((100, 5), 2, dtype="i4"), np.arange(1, 6, dtype="i4")
126+
)
127+
assert (dpt.asnumpy(r) == expected.astype(r.dtype)).all()
128+
129+
r2 = dpt.pow(v, m)
130+
expected2 = np.power(
131+
np.arange(1, 6, dtype="i4"), np.full((100, 5), 2, dtype="i4")
132+
)
133+
assert (dpt.asnumpy(r2) == expected2.astype(r2.dtype)).all()
134+
135+
136+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
137+
def test_pow_python_scalar(arr_dt):
138+
q = get_queue_or_skip()
139+
skip_if_dtype_not_supported(arr_dt, q)
140+
141+
X = dpt.ones((10, 10), dtype=arr_dt, sycl_queue=q)
142+
py_ones = (
143+
bool(1),
144+
int(1),
145+
float(1),
146+
complex(1),
147+
np.float32(1),
148+
ctypes.c_int(1),
149+
)
150+
for sc in py_ones:
151+
R = dpt.pow(X, sc)
152+
assert isinstance(R, dpt.usm_ndarray)
153+
R = dpt.pow(sc, X)
154+
assert isinstance(R, dpt.usm_ndarray)

0 commit comments

Comments
 (0)