Skip to content

Commit 72520b7

Browse files
committed
Adds tests for angle
1 parent 1388283 commit 72520b7

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/angle.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//=== angle.hpp - Unary function ANGLE ------
2-
//*-C++-*--/===//
1+
//=== angle.hpp - Unary function ANGLE ------*-C++-*--/===//
32
//
43
// Data Parallel Control (dpctl)
54
//

dpctl/tests/elementwise/test_angle.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import itertools
18+
19+
import numpy as np
20+
import pytest
21+
22+
import dpctl.tensor as dpt
23+
from dpctl.tensor._type_utils import _can_cast
24+
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
25+
26+
from .utils import _all_dtypes, _complex_fp_dtypes, _no_complex_dtypes
27+
28+
29+
@pytest.mark.parametrize("dtype", _all_dtypes)
30+
def test_angle_out_type(dtype):
31+
q = get_queue_or_skip()
32+
skip_if_dtype_not_supported(dtype, q)
33+
34+
x = dpt.asarray(1, dtype=dtype, sycl_queue=q)
35+
dt = dpt.dtype(dtype)
36+
dev = q.sycl_device
37+
_fp16 = dev.has_aspect_fp16
38+
_fp64 = dev.has_aspect_fp64
39+
if _can_cast(dt, dpt.complex64, _fp16, _fp64):
40+
assert dpt.angle(x).dtype == dpt.float32
41+
else:
42+
assert dpt.angle(x).dtype == dpt.float64
43+
44+
45+
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:])
46+
def test_angle_real(dtype):
47+
q = get_queue_or_skip()
48+
skip_if_dtype_not_supported(dtype, q)
49+
50+
x = dpt.arange(10, dtype=dtype, sycl_queue=q)
51+
r = dpt.angle(x)
52+
53+
assert dpt.all(r == 0)
54+
55+
56+
@pytest.mark.parametrize("dtype", _complex_fp_dtypes)
57+
def test_angle_complex(dtype):
58+
q = get_queue_or_skip()
59+
skip_if_dtype_not_supported(dtype, q)
60+
61+
tol = 8 * dpt.finfo(dtype).resolution
62+
vals = dpt.pi * dpt.arange(10, dtype=dpt.finfo(dtype).dtype, sycl_queue=q)
63+
64+
x = dpt.zeros(10, dtype=dtype, sycl_queue=q)
65+
66+
x.imag[...] = vals
67+
r = dpt.angle(x)
68+
expected = dpt.atan2(x.imag, x.real)
69+
assert dpt.allclose(r, expected, atol=tol, rtol=tol)
70+
71+
x.real[...] += dpt.pi
72+
r = dpt.angle(x)
73+
expected = dpt.atan2(x.imag, x.real)
74+
assert dpt.allclose(r, expected, atol=tol, rtol=tol)
75+
76+
77+
@pytest.mark.parametrize("dtype", ["c8", "c16"])
78+
def test_angle_special_cases(dtype):
79+
q = get_queue_or_skip()
80+
skip_if_dtype_not_supported(dtype, q)
81+
82+
vals = [np.nan, -np.nan, np.inf, -np.inf, +0.0, -0.0]
83+
vals = [complex(*val) for val in itertools.product(vals, repeat=2)]
84+
85+
x = dpt.asarray(vals, dtype=dtype, sycl_queue=q)
86+
87+
r = dpt.angle(x)
88+
expected = dpt.atan2(x.imag, x.real)
89+
90+
tol = 8 * dpt.finfo(dtype).resolution
91+
92+
assert dpt.allclose(r, expected, atol=tol, rtol=tol, equal_nan=True)

0 commit comments

Comments
 (0)