Skip to content

Commit 44d5df2

Browse files
committed
Implements hypot, negative, positive, pow, and square
1 parent 36a7cd7 commit 44d5df2

File tree

8 files changed

+1604
-12
lines changed

8 files changed

+1604
-12
lines changed

dpctl/tensor/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
floor_divide,
104104
greater,
105105
greater_equal,
106+
hypot,
106107
imag,
107108
isfinite,
108109
isinf,
@@ -112,11 +113,15 @@
112113
log,
113114
log1p,
114115
multiply,
116+
negative,
115117
not_equal,
118+
positive,
119+
pow,
116120
proj,
117121
real,
118122
sin,
119123
sqrt,
124+
square,
120125
subtract,
121126
)
122127
from ._reduction import sum
@@ -204,6 +209,7 @@
204209
"expm1",
205210
"greater",
206211
"greater_equal",
212+
"hypot",
207213
"imag",
208214
"isinf",
209215
"isnan",
@@ -212,12 +218,16 @@
212218
"less_equal",
213219
"log",
214220
"log1p",
221+
"negative",
222+
"positive",
215223
"proj",
216224
"real",
217225
"sin",
218226
"sqrt",
227+
"square",
219228
"divide",
220229
"multiply",
230+
"pow",
221231
"subtract",
222232
"equal",
223233
"not_equal",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,27 @@
615615
)
616616

617617
# U25: ==== NEGATIVE (x)
618-
# FIXME: implement U25
618+
_negative_docstring_ = """
619+
negative(x, out=None, order='K')
620+
621+
Computes the numerical negative elementwise.
622+
Args:
623+
x (usm_ndarray):
624+
Input array, expected to have numeric data type.
625+
out (usm_ndarray):
626+
Output array to populate. Array must have the correct
627+
shape and the expected data type.
628+
order ("C","F","A","K", optional): memory layout of the new
629+
output array, if parameter `out` is `None`.
630+
Default: "K".
631+
Return:
632+
usm_ndarray:
633+
An array containing the element-wise negative values.
634+
"""
635+
636+
negative = UnaryElementwiseFunc(
637+
"negative", ti._negative_result_type, ti._negative, _negative_docstring_
638+
)
619639

620640
# B20: ==== NOT_EQUAL (x1, x2)
621641
_not_equal_docstring_ = """
@@ -647,10 +667,48 @@
647667
)
648668

649669
# U26: ==== POSITIVE (x)
650-
# FIXME: implement U26
670+
_positive_docstring_ = """
671+
positive(x, out=None, order='K')
672+
673+
Computes the numerical positive element-wise.
674+
Args:
675+
x (usm_ndarray):
676+
Input array, expected to have numeric data type.
677+
out (usm_ndarray):
678+
Output array to populate. Array must have the correct
679+
shape and the expected data type.
680+
order ("C","F","A","K", optional): memory layout of the new
681+
output array, if parameter `out` is `None`.
682+
Default: "K".
683+
Return:
684+
usm_ndarray:
685+
An array containing the element-wise positive values.
686+
"""
687+
688+
positive = UnaryElementwiseFunc(
689+
"positive", ti._positive_result_type, ti._positive, _positive_docstring_
690+
)
651691

652692
# B21: ==== POW (x1, x2)
653-
# FIXME: implement B21
693+
_pow_docstring_ = """
694+
pow(x1, x2, out=None, order='K')
695+
696+
Calculates `x1_i` raised to `x2_i` for each element `x1_i` of the input array
697+
`x1` with the respective element `x2_i` of the input array `x2`.
698+
699+
Args:
700+
x1 (usm_ndarray):
701+
First input array, expected to have a numeric data type.
702+
x2 (usm_ndarray):
703+
Second input array, also expected to have a numeric data type.
704+
Returns:
705+
usm_narray:
706+
an array containing the element-wise result. The data type of
707+
the returned array is determined by the Type Promotion Rules.
708+
"""
709+
pow = BinaryElementwiseFunc(
710+
"pow", ti._pow_result_type, ti._pow, _pow_docstring_
711+
)
654712

655713
# U??: ==== PROJ (x)
656714
_proj_docstring = """
@@ -738,7 +796,15 @@
738796
# FIXME: implement U31
739797

740798
# U32: ==== SQUARE (x)
741-
# FIXME: implement U32
799+
_square_docstring_ = """
800+
square(x, out=None, order='K')
801+
802+
Computes `x_i**2` for each element `x_i` for input array `x`.
803+
"""
804+
805+
square = UnaryElementwiseFunc(
806+
"square", ti._square_result_type, ti._square, _square_docstring_
807+
)
742808

743809
# U33: ==== SQRT (x)
744810
_sqrt_docstring_ = """
@@ -806,3 +872,24 @@
806872

807873
# U36: ==== TRUNC (x)
808874
# FIXME: implement U36
875+
876+
# B24: ==== HYPOT (x)
877+
_hypot_docstring_ = """
878+
hypot(x1, x2, out=None, order='K')
879+
880+
Calculates `sqrt(x1_i**2 + x2_i**2)` for each element `x1_i` of input array `x1`
881+
and `x2_i` of input array `x2`.
882+
883+
Args:
884+
x1 (usm_ndarray):
885+
First input array, expected to have a real data type.
886+
x2 (usm_ndarray):
887+
Second input array, also expected to have a real data type.
888+
Returns:
889+
usm_narray:
890+
an array containing the element-wise hypotenuse. The data type
891+
of the returned array is determined by the Type Promotion Rules.
892+
"""
893+
hypot = BinaryElementwiseFunc(
894+
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
895+
)
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
//=== HYPOT.hpp - Binary function HYPOT ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of HYPOT(x1, x2)
23+
/// function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace hypot
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct HypotFunctor
53+
{
54+
55+
using supports_sg_loadstore = std::negation<
56+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
57+
using supports_vec = std::negation<
58+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
60+
resT operator()(const argT1 &in1, const argT2 &in2)
61+
{
62+
return std::hypot(in1, in2);
63+
}
64+
65+
template <int vec_sz>
66+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
67+
const sycl::vec<argT2, vec_sz> &in2)
68+
{
69+
auto res = sycl::hypot(in1, in2);
70+
if constexpr (std::is_same_v<resT,
71+
typename decltype(res)::element_type>) {
72+
return res;
73+
}
74+
else {
75+
using dpctl::tensor::type_utils::vec_cast;
76+
77+
return vec_cast<resT, typename decltype(res)::element_type, vec_sz>(
78+
res);
79+
}
80+
}
81+
};
82+
83+
template <typename argT1,
84+
typename argT2,
85+
typename resT,
86+
unsigned int vec_sz = 4,
87+
unsigned int n_vecs = 2>
88+
using HypotContigFunctor =
89+
elementwise_common::BinaryContigFunctor<argT1,
90+
argT2,
91+
resT,
92+
HypotFunctor<argT1, argT2, resT>,
93+
vec_sz,
94+
n_vecs>;
95+
96+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
97+
using HypotStridedFunctor =
98+
elementwise_common::BinaryStridedFunctor<argT1,
99+
argT2,
100+
resT,
101+
IndexerT,
102+
HypotFunctor<argT1, argT2, resT>>;
103+
104+
template <typename T1, typename T2> struct HypotOutputType
105+
{
106+
using value_type = typename std::disjunction< // disjunction is C++17
107+
// feature, supported by DPC++
108+
td_ns::BinaryTypeMapResultEntry<T1,
109+
sycl::half,
110+
T2,
111+
sycl::half,
112+
sycl::half>,
113+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
114+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
115+
td_ns::DefaultResultEntry<void>>::result_type;
116+
};
117+
118+
template <typename argT1,
119+
typename argT2,
120+
typename resT,
121+
unsigned int vec_sz,
122+
unsigned int n_vecs>
123+
class hypot_contig_kernel;
124+
125+
template <typename argTy1, typename argTy2>
126+
sycl::event hypot_contig_impl(sycl::queue exec_q,
127+
size_t nelems,
128+
const char *arg1_p,
129+
py::ssize_t arg1_offset,
130+
const char *arg2_p,
131+
py::ssize_t arg2_offset,
132+
char *res_p,
133+
py::ssize_t res_offset,
134+
const std::vector<sycl::event> &depends = {})
135+
{
136+
return elementwise_common::binary_contig_impl<
137+
argTy1, argTy2, HypotOutputType, HypotContigFunctor,
138+
hypot_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
139+
arg2_offset, res_p, res_offset, depends);
140+
}
141+
142+
template <typename fnT, typename T1, typename T2> struct HypotContigFactory
143+
{
144+
fnT get()
145+
{
146+
if constexpr (std::is_same_v<
147+
typename HypotOutputType<T1, T2>::value_type, void>)
148+
{
149+
fnT fn = nullptr;
150+
return fn;
151+
}
152+
else {
153+
fnT fn = hypot_contig_impl<T1, T2>;
154+
return fn;
155+
}
156+
}
157+
};
158+
159+
template <typename fnT, typename T1, typename T2> struct HypotTypeMapFactory
160+
{
161+
/*! @brief get typeid for output type of std::hypot(T1 x, T2 y) */
162+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
163+
{
164+
using rT = typename HypotOutputType<T1, T2>::value_type;
165+
;
166+
return td_ns::GetTypeid<rT>{}.get();
167+
}
168+
};
169+
170+
template <typename T1, typename T2, typename resT, typename IndexerT>
171+
class hypot_strided_strided_kernel;
172+
173+
template <typename argTy1, typename argTy2>
174+
sycl::event
175+
hypot_strided_impl(sycl::queue exec_q,
176+
size_t nelems,
177+
int nd,
178+
const py::ssize_t *shape_and_strides,
179+
const char *arg1_p,
180+
py::ssize_t arg1_offset,
181+
const char *arg2_p,
182+
py::ssize_t arg2_offset,
183+
char *res_p,
184+
py::ssize_t res_offset,
185+
const std::vector<sycl::event> &depends,
186+
const std::vector<sycl::event> &additional_depends)
187+
{
188+
return elementwise_common::binary_strided_impl<
189+
argTy1, argTy2, HypotOutputType, HypotStridedFunctor,
190+
hypot_strided_strided_kernel>(
191+
exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p,
192+
arg2_offset, res_p, res_offset, depends, additional_depends);
193+
}
194+
195+
template <typename fnT, typename T1, typename T2> struct HypotStridedFactory
196+
{
197+
fnT get()
198+
{
199+
if constexpr (std::is_same_v<
200+
typename HypotOutputType<T1, T2>::value_type, void>)
201+
{
202+
fnT fn = nullptr;
203+
return fn;
204+
}
205+
else {
206+
fnT fn = hypot_strided_impl<T1, T2>;
207+
return fn;
208+
}
209+
}
210+
};
211+
212+
} // namespace hypot
213+
} // namespace kernels
214+
} // namespace tensor
215+
} // namespace dpctl

0 commit comments

Comments
 (0)