Skip to content

Commit bf2d1ff

Browse files
Merge pull request #1317 from IntelPython/implement-atan2
Implement tensor.atan2 and tensor.signbit
2 parents 438ded5 + 2e9e185 commit bf2d1ff

File tree

7 files changed

+1191
-4
lines changed

7 files changed

+1191
-4
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
asin,
100100
asinh,
101101
atan,
102+
atan2,
102103
atanh,
103104
bitwise_and,
104105
bitwise_invert,
@@ -144,6 +145,7 @@
144145
remainder,
145146
round,
146147
sign,
148+
signbit,
147149
sin,
148150
sinh,
149151
sqrt,
@@ -237,6 +239,7 @@
237239
"asin",
238240
"asinh",
239241
"atan",
242+
"atan2",
240243
"atanh",
241244
"bitwise_and",
242245
"bitwise_invert",
@@ -282,6 +285,7 @@
282285
"remainder",
283286
"round",
284287
"sign",
288+
"signbit",
285289
"sin",
286290
"sinh",
287291
"sqrt",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,36 @@
209209
)
210210

211211
# B02: ===== ATAN2 (x1, x2)
212-
# FIXME: implemetn B02
212+
_atan2_docstring_ = """
213+
atan2(x1, x2, out=None, order='K')
214+
215+
Calculates the inverse tangent of the quotient `x1_i/x2_i` for each element
216+
`x1_i` of the input array `x1` with the respective element `x2_i` of the
217+
input array `x2`. Each element-wise result is expressed in radians.
218+
219+
Args:
220+
x1 (usm_ndarray):
221+
First input array, expected to have a real-valued floating-point
222+
data type.
223+
x2 (usm_ndarray):
224+
Second input array, also expected to have a real-valued
225+
floating-point data type.
226+
out ({None, usm_ndarray}, optional):
227+
Output array to populate.
228+
Array have the correct shape and the expected data type.
229+
order ("C","F","A","K", optional):
230+
Memory layout of the newly output array, if parameter `out` is `None`.
231+
Default: "K".
232+
Returns:
233+
usm_narray:
234+
An array containing the inverse tangent of the quotient `x1`/`x2`.
235+
The returned array must have a real-valued floating-point data type
236+
determined by Type Promotion Rules.
237+
"""
238+
239+
atan2 = BinaryElementwiseFunc(
240+
"atan2", ti._atan2_result_type, ti._atan2, _atan2_docstring_
241+
)
213242

214243
# U07: ===== ATANH (x)
215244
_atanh_docstring = """
@@ -1404,6 +1433,32 @@
14041433
"sign", ti._sign_result_type, ti._sign, _sign_docstring
14051434
)
14061435

1436+
# ==== SIGNBIT (x)
1437+
_signbit_docstring = """
1438+
signbit(x, out=None, order='K')
1439+
1440+
Computes an indication of whether the sign bit of each element `x_i` of
1441+
input array `x` is set.
1442+
1443+
Args:
1444+
x (usm_ndarray):
1445+
Input array, expected to have a numeric data type.
1446+
out ({None, usm_ndarray}, optional):
1447+
Output array to populate.
1448+
Array have the correct shape and the expected data type.
1449+
order ("C","F","A","K", optional):
1450+
Memory layout of the newly output array, if parameter `out` is `None`.
1451+
Default: "K".
1452+
Returns:
1453+
usm_narray:
1454+
An array containing the element-wise results. The returned array
1455+
must have a data type of `bool`.
1456+
"""
1457+
1458+
signbit = UnaryElementwiseFunc(
1459+
"signbit", ti._signbit_result_type, ti._signbit, _signbit_docstring
1460+
)
1461+
14071462
# U30: ==== SIN (x)
14081463
_sin_docstring = """
14091464
sin(x, out=None, order='K')
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
//=== ATAN2.hpp - Binary function ATAN2 ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of ATAN2(x1, x2)
23+
/// function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace atan2
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct Atan2Functor
53+
{
54+
55+
using supports_sg_loadstore = std::true_type;
56+
using supports_vec = std::false_type;
57+
58+
resT operator()(const argT1 &in1, const argT2 &in2)
59+
{
60+
if (std::isinf(in2) && !std::signbit(in2)) {
61+
if (std::isfinite(in1)) {
62+
return std::copysign(resT(0), in1);
63+
}
64+
}
65+
return std::atan2(in1, in2);
66+
}
67+
};
68+
69+
template <typename argT1,
70+
typename argT2,
71+
typename resT,
72+
unsigned int vec_sz = 4,
73+
unsigned int n_vecs = 2>
74+
using Atan2ContigFunctor =
75+
elementwise_common::BinaryContigFunctor<argT1,
76+
argT2,
77+
resT,
78+
Atan2Functor<argT1, argT2, resT>,
79+
vec_sz,
80+
n_vecs>;
81+
82+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
83+
using Atan2StridedFunctor =
84+
elementwise_common::BinaryStridedFunctor<argT1,
85+
argT2,
86+
resT,
87+
IndexerT,
88+
Atan2Functor<argT1, argT2, resT>>;
89+
90+
template <typename T1, typename T2> struct Atan2OutputType
91+
{
92+
using value_type = typename std::disjunction< // disjunction is C++17
93+
// feature, supported by DPC++
94+
td_ns::BinaryTypeMapResultEntry<T1,
95+
sycl::half,
96+
T2,
97+
sycl::half,
98+
sycl::half>,
99+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
100+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
101+
td_ns::DefaultResultEntry<void>>::result_type;
102+
};
103+
104+
template <typename argT1,
105+
typename argT2,
106+
typename resT,
107+
unsigned int vec_sz,
108+
unsigned int n_vecs>
109+
class atan2_contig_kernel;
110+
111+
template <typename argTy1, typename argTy2>
112+
sycl::event atan2_contig_impl(sycl::queue exec_q,
113+
size_t nelems,
114+
const char *arg1_p,
115+
py::ssize_t arg1_offset,
116+
const char *arg2_p,
117+
py::ssize_t arg2_offset,
118+
char *res_p,
119+
py::ssize_t res_offset,
120+
const std::vector<sycl::event> &depends = {})
121+
{
122+
return elementwise_common::binary_contig_impl<
123+
argTy1, argTy2, Atan2OutputType, Atan2ContigFunctor,
124+
atan2_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
125+
arg2_offset, res_p, res_offset, depends);
126+
}
127+
128+
template <typename fnT, typename T1, typename T2> struct Atan2ContigFactory
129+
{
130+
fnT get()
131+
{
132+
if constexpr (std::is_same_v<
133+
typename Atan2OutputType<T1, T2>::value_type, void>)
134+
{
135+
fnT fn = nullptr;
136+
return fn;
137+
}
138+
else {
139+
fnT fn = atan2_contig_impl<T1, T2>;
140+
return fn;
141+
}
142+
}
143+
};
144+
145+
template <typename fnT, typename T1, typename T2> struct Atan2TypeMapFactory
146+
{
147+
/*! @brief get typeid for output type of std::hypot(T1 x, T2 y) */
148+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
149+
{
150+
using rT = typename Atan2OutputType<T1, T2>::value_type;
151+
;
152+
return td_ns::GetTypeid<rT>{}.get();
153+
}
154+
};
155+
156+
template <typename T1, typename T2, typename resT, typename IndexerT>
157+
class atan2_strided_kernel;
158+
159+
template <typename argTy1, typename argTy2>
160+
sycl::event
161+
atan2_strided_impl(sycl::queue exec_q,
162+
size_t nelems,
163+
int nd,
164+
const py::ssize_t *shape_and_strides,
165+
const char *arg1_p,
166+
py::ssize_t arg1_offset,
167+
const char *arg2_p,
168+
py::ssize_t arg2_offset,
169+
char *res_p,
170+
py::ssize_t res_offset,
171+
const std::vector<sycl::event> &depends,
172+
const std::vector<sycl::event> &additional_depends)
173+
{
174+
return elementwise_common::binary_strided_impl<
175+
argTy1, argTy2, Atan2OutputType, Atan2StridedFunctor,
176+
atan2_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
177+
arg1_offset, arg2_p, arg2_offset, res_p,
178+
res_offset, depends, additional_depends);
179+
}
180+
181+
template <typename fnT, typename T1, typename T2> struct Atan2StridedFactory
182+
{
183+
fnT get()
184+
{
185+
if constexpr (std::is_same_v<
186+
typename Atan2OutputType<T1, T2>::value_type, void>)
187+
{
188+
fnT fn = nullptr;
189+
return fn;
190+
}
191+
else {
192+
fnT fn = atan2_strided_impl<T1, T2>;
193+
return fn;
194+
}
195+
}
196+
};
197+
198+
} // namespace atan2
199+
} // namespace kernels
200+
} // namespace tensor
201+
} // namespace dpctl

0 commit comments

Comments
 (0)