Skip to content

Commit b736052

Browse files
committed
impl_real_imag_conj
1 parent 165727a commit b736052

File tree

9 files changed

+1259
-183
lines changed

9 files changed

+1259
-183
lines changed

dpctl/tensor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,21 @@
9494
from ._elementwise_funcs import (
9595
abs,
9696
add,
97+
conj,
9798
cos,
9899
divide,
99100
equal,
100101
exp,
101102
expm1,
103+
imag,
102104
isfinite,
103105
isinf,
104106
isnan,
105107
log,
106108
log1p,
107109
multiply,
110+
proj,
111+
real,
108112
sin,
109113
sqrt,
110114
subtract,
@@ -188,14 +192,18 @@
188192
"inf",
189193
"abs",
190194
"add",
195+
"conj",
191196
"cos",
192197
"exp",
193198
"expm1",
199+
"imag",
194200
"isinf",
195201
"isnan",
196202
"isfinite",
197203
"log",
198204
"log1p",
205+
"proj",
206+
"real",
199207
"sin",
200208
"sqrt",
201209
"exp",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,29 @@
113113
# FIXME: implement U09
114114

115115
# U10: ==== CONJ (x)
116-
# FIXME: implement U10
116+
_conj_docstring = """
117+
conj(x, out=None, order='K')
118+
119+
Computes conjugate of each element `x_i` for input array `x`.
120+
121+
Args:
122+
x (usm_ndarray):
123+
Input array, expected to have numeric data type.
124+
out ({None, usm_ndarray}, optional):
125+
Output array to populate.
126+
Array must have the correct shape and the expected data type.
127+
order ("C","F","A","K", optional):
128+
Memory layout of the newly created output array, if parameter `out` is `None`.
129+
Default: "K".
130+
Returns:
131+
usm_ndarray:
132+
An array containing the element-wise conjugate values. The data type
133+
of the returned array is determined by the Type Promotion Rules.
134+
"""
135+
136+
conj = UnaryElementwiseFunc(
137+
"conj", ti._conj_result_type, ti._conj, _conj_docstring
138+
)
117139

118140
# U11: ==== COS (x)
119141
_cos_docstring = """
@@ -257,7 +279,30 @@
257279
# FIXME: implement B12
258280

259281
# U16: ==== IMAG (x)
260-
# FIXME: implement U16
282+
_imag_docstring = """
283+
imag(x, out=None, order='K')
284+
285+
Computes imaginary part of each element `x_i` for input array `x`.
286+
287+
Args:
288+
x (usm_ndarray):
289+
Input array, expected to have numeric data type.
290+
out ({None, usm_ndarray}, optional):
291+
Output array to populate.
292+
Array must have the correct shape and the expected data type.
293+
order ("C","F","A","K", optional):
294+
Memory layout of the newly created output array, if parameter `out` is `None`.
295+
Default: "K".
296+
Returns:
297+
usm_ndarray:
298+
An array containing the element-wise imaginary component of input.
299+
The data type of the returned array is determined
300+
by the Type Promotion Rules.
301+
"""
302+
303+
imag = UnaryElementwiseFunc(
304+
"imag", ti._imag_result_type, ti._imag, _imag_docstring
305+
)
261306

262307
# U17: ==== ISFINITE (x)
263308
_isfinite_docstring_ = """
@@ -443,8 +488,55 @@
443488
# B21: ==== POW (x1, x2)
444489
# FIXME: implement B21
445490

491+
# U??: ==== PROJ (x)
492+
_proj_docstring = """
493+
proj(x, out=None, order='K')
494+
495+
Computes projection of each element `x_i` for input array `x`.
496+
497+
Args:
498+
x (usm_ndarray):
499+
Input array, expected to have numeric data type.
500+
out ({None, usm_ndarray}, optional):
501+
Output array to populate.
502+
Array must have the correct shape and the expected data type.
503+
order ("C","F","A","K", optional):
504+
Memory layout of the newly created output array, if parameter `out` is `None`.
505+
Default: "K".
506+
Returns:
507+
usm_ndarray:
508+
An array containing the element-wise projection. The data
509+
type of the returned array is determined by the Type Promotion Rules.
510+
"""
511+
512+
proj = UnaryElementwiseFunc(
513+
"proj", ti._proj_result_type, ti._proj, _proj_docstring
514+
)
515+
446516
# U27: ==== REAL (x)
447-
# FIXME: implement U27
517+
_real_docstring = """
518+
real(x, out=None, order='K')
519+
520+
Computes real part of each element `x_i` for input array `x`.
521+
522+
Args:
523+
x (usm_ndarray):
524+
Input array, expected to have numeric data type.
525+
out ({None, usm_ndarray}, optional):
526+
Output array to populate.
527+
Array must have the correct shape and the expected data type.
528+
order ("C","F","A","K", optional):
529+
Memory layout of the newly created output array, if parameter `out` is `None`.
530+
Default: "K".
531+
Returns:
532+
usm_ndarray:
533+
An array containing the element-wise real component of input. The data
534+
type of the returned array is determined by the Type Promotion Rules.
535+
"""
536+
537+
real = UnaryElementwiseFunc(
538+
"real", ti._real_result_type, ti._real, _real_docstring
539+
)
448540

449541
# B22: ==== REMAINDER (x1, x2)
450542
# FIXME: implement B22
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
//=== conj.hpp - Unary function CONJ ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of CONJ(x) function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cmath>
29+
#include <complex>
30+
#include <cstddef>
31+
#include <cstdint>
32+
#include <type_traits>
33+
34+
#include "kernels/elementwise_functions/common.hpp"
35+
36+
#include "utils/offset_utils.hpp"
37+
#include "utils/type_dispatch.hpp"
38+
#include "utils/type_utils.hpp"
39+
#include <pybind11/pybind11.h>
40+
41+
namespace dpctl
42+
{
43+
namespace tensor
44+
{
45+
namespace kernels
46+
{
47+
namespace conj
48+
{
49+
50+
namespace py = pybind11;
51+
namespace td_ns = dpctl::tensor::type_dispatch;
52+
53+
using dpctl::tensor::type_utils::is_complex;
54+
55+
template <typename argT, typename resT> struct ConjFunctor
56+
{
57+
58+
// is function constant for given argT
59+
using is_constant = typename std::false_type;
60+
// constant value, if constant
61+
// constexpr resT constant_value = resT{};
62+
// is function defined for sycl::vec
63+
using supports_vec = typename std::false_type;
64+
// do both argT and resT support sub-group load/store operations
65+
using supports_sg_loadstore = typename std::negation<
66+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
67+
68+
resT operator()(const argT &in)
69+
{
70+
if constexpr (is_complex<argT>::value) {
71+
return std::conj(in);
72+
}
73+
else {
74+
if constexpr (!std::is_same_v<argT, bool>)
75+
static_assert(std::is_same_v<resT, argT>);
76+
return in;
77+
}
78+
}
79+
};
80+
81+
template <typename argTy,
82+
typename resTy = argTy,
83+
unsigned int vec_sz = 4,
84+
unsigned int n_vecs = 2>
85+
using ConjContigFunctor = elementwise_common::
86+
UnaryContigFunctor<argTy, resTy, ConjFunctor<argTy, resTy>, vec_sz, n_vecs>;
87+
88+
template <typename argTy, typename resTy, typename IndexerT>
89+
using ConjStridedFunctor = elementwise_common::
90+
UnaryStridedFunctor<argTy, resTy, IndexerT, ConjFunctor<argTy, resTy>>;
91+
92+
template <typename T> struct ConjOutputType
93+
{
94+
using value_type = typename std::disjunction< // disjunction is C++17
95+
// feature, supported by DPC++
96+
td_ns::TypeMapResultEntry<T, bool, int8_t>,
97+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
98+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
99+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
100+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
101+
td_ns::TypeMapResultEntry<T, std::int8_t>,
102+
td_ns::TypeMapResultEntry<T, std::int16_t>,
103+
td_ns::TypeMapResultEntry<T, std::int32_t>,
104+
td_ns::TypeMapResultEntry<T, std::int64_t>,
105+
td_ns::TypeMapResultEntry<T, sycl::half>,
106+
td_ns::TypeMapResultEntry<T, float>,
107+
td_ns::TypeMapResultEntry<T, double>,
108+
td_ns::TypeMapResultEntry<T, std::complex<float>>,
109+
td_ns::TypeMapResultEntry<T, std::complex<double>>,
110+
td_ns::DefaultResultEntry<void>>::result_type;
111+
};
112+
113+
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
114+
class conj_contig_kernel;
115+
116+
template <typename argTy>
117+
sycl::event conj_contig_impl(sycl::queue exec_q,
118+
size_t nelems,
119+
const char *arg_p,
120+
char *res_p,
121+
const std::vector<sycl::event> &depends = {})
122+
{
123+
return elementwise_common::unary_contig_impl<
124+
argTy, ConjOutputType, ConjContigFunctor, conj_contig_kernel>(
125+
exec_q, nelems, arg_p, res_p, depends);
126+
}
127+
128+
template <typename fnT, typename T> struct ConjContigFactory
129+
{
130+
fnT get()
131+
{
132+
if constexpr (std::is_same_v<typename ConjOutputType<T>::value_type,
133+
void>) {
134+
fnT fn = nullptr;
135+
return fn;
136+
}
137+
else {
138+
fnT fn = conj_contig_impl<T>;
139+
return fn;
140+
}
141+
}
142+
};
143+
144+
template <typename fnT, typename T> struct ConjTypeMapFactory
145+
{
146+
/*! @brief get typeid for output type of std::conj(T x) */
147+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
148+
{
149+
using rT = typename ConjOutputType<T>::value_type;
150+
return td_ns::GetTypeid<rT>{}.get();
151+
}
152+
};
153+
154+
template <typename T1, typename T2, typename T3> class conj_strided_kernel;
155+
156+
template <typename argTy>
157+
sycl::event
158+
conj_strided_impl(sycl::queue exec_q,
159+
size_t nelems,
160+
int nd,
161+
const py::ssize_t *shape_and_strides,
162+
const char *arg_p,
163+
py::ssize_t arg_offset,
164+
char *res_p,
165+
py::ssize_t res_offset,
166+
const std::vector<sycl::event> &depends,
167+
const std::vector<sycl::event> &additional_depends)
168+
{
169+
return elementwise_common::unary_strided_impl<
170+
argTy, ConjOutputType, ConjStridedFunctor, conj_strided_kernel>(
171+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
172+
res_offset, depends, additional_depends);
173+
}
174+
175+
template <typename fnT, typename T> struct ConjStridedFactory
176+
{
177+
fnT get()
178+
{
179+
if constexpr (std::is_same_v<typename ConjOutputType<T>::value_type,
180+
void>) {
181+
fnT fn = nullptr;
182+
return fn;
183+
}
184+
else {
185+
fnT fn = conj_strided_impl<T>;
186+
return fn;
187+
}
188+
}
189+
};
190+
191+
} // namespace conj
192+
} // namespace kernels
193+
} // namespace tensor
194+
} // namespace dpctl

0 commit comments

Comments
 (0)