Skip to content

Commit 6d71e46

Browse files
Implemented isinf, isfinite, reused templates to define Contig and Strided unary functors
Added tests
1 parent 872f372 commit 6d71e46

File tree

10 files changed

+1091
-284
lines changed

10 files changed

+1091
-284
lines changed

dpctl/tensor/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
from dpctl.tensor._usmarray import usm_ndarray
9191

9292
from ._constants import e, inf, nan, newaxis, pi
93-
from ._elementwise_funcs import abs, add, cos, isnan
93+
from ._elementwise_funcs import abs, add, cos, isfinite, isinf, isnan
9494

9595
__all__ = [
9696
"Device",
@@ -168,5 +168,7 @@
168168
"abs",
169169
"add",
170170
"cos",
171+
"isinf",
171172
"isnan",
173+
"isfinite",
172174
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,32 @@
4242

4343
cos = UnaryElementwiseFunc("cos", ti._cos_result_type, ti._cos, _cos_docstring)
4444

45+
# ISFINITE
46+
47+
_isfinite_docstring_ = """
48+
Computes if every element of input array is a finite number.
49+
"""
50+
51+
isfinite = UnaryElementwiseFunc(
52+
"isfinite", ti._isfinite_result_type, ti._isfinite, _isfinite_docstring_
53+
)
54+
4555
# ISNAN
4656

4757
_isnan_docstring_ = """
48-
Computes if ever element of input array is a NaN.
58+
Computes if every element of input array is a NaN.
4959
"""
5060

5161
isnan = UnaryElementwiseFunc(
5262
"isnan", ti._isnan_result_type, ti._isnan, _isnan_docstring_
5363
)
64+
65+
# ISINF
66+
67+
_isinf_docstring_ = """
68+
Computes if every element of input array is an infinity.
69+
"""
70+
71+
isinf = UnaryElementwiseFunc(
72+
"isinf", ti._isinf_result_type, ti._isinf, _isinf_docstring_
73+
)

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 25 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
#pragma once
22
#include <CL/sycl.hpp>
33

4+
#include "kernels/elementwise_functions/common.hpp"
5+
46
#include "utils/offset_utils.hpp"
57
#include "utils/type_dispatch.hpp"
68
#include "utils/type_utils.hpp"
79
#include <pybind11/pybind11.h>
810

11+
#include <iostream>
12+
913
namespace dpctl
1014
{
1115
namespace tensor
@@ -18,120 +22,40 @@ namespace abs
1822
namespace py = pybind11;
1923
namespace td_ns = dpctl::tensor::type_dispatch;
2024

21-
template <typename argT,
22-
typename resT = argT,
23-
unsigned int vec_sz = 4,
24-
unsigned int n_vecs = 2>
25-
struct AbsContigFunctor
25+
using dpctl::tensor::type_utils::is_complex;
26+
27+
template <typename argT, typename resT> struct AbsFunctor
2628
{
27-
private:
28-
const argT *in = nullptr;
29-
resT *out = nullptr;
30-
const size_t nelems_;
3129

32-
public:
33-
AbsContigFunctor(const argT *inp, resT *res, const size_t n_elems)
34-
: in(inp), out(res), nelems_(n_elems)
35-
{
36-
}
30+
using is_constant = typename std::false_type;
31+
// constexpr resT constant_value = resT{};
32+
using supports_vec = typename std::false_type;
33+
using supports_sg_loadstore = typename std::negation<
34+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
3735

38-
void operator()(sycl::nd_item<1> ndit) const
36+
resT operator()(const argT &x)
3937
{
40-
/* Each work-item processes vec_sz elements, contiguous in memory */
41-
/* NOTE: vec_sz must divide sg.max_local_range()[0] */
4238

4339
if constexpr (std::is_same_v<argT, bool> ||
4440
(std::is_integral<argT>::value &&
4541
std::is_unsigned<argT>::value))
4642
{
4743
static_assert(std::is_same_v<resT, argT>);
48-
49-
auto sg = ndit.get_sub_group();
50-
std::uint8_t sgSize = sg.get_local_range()[0];
51-
std::uint8_t max_sgSize = sg.get_max_local_range()[0];
52-
size_t base = n_vecs * vec_sz *
53-
(ndit.get_group(0) * ndit.get_local_range(0) +
54-
sg.get_group_id()[0] * max_sgSize);
55-
56-
if (base + n_vecs * vec_sz * sgSize < nelems_ &&
57-
sgSize == max_sgSize) {
58-
using in_ptrT =
59-
sycl::multi_ptr<const argT,
60-
sycl::access::address_space::global_space>;
61-
using out_ptrT =
62-
sycl::multi_ptr<resT,
63-
sycl::access::address_space::global_space>;
64-
sycl::vec<argT, vec_sz> arg_vec;
65-
66-
#pragma unroll
67-
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
68-
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
69-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
70-
arg_vec);
71-
}
72-
}
73-
else {
74-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
75-
k += sgSize) {
76-
out[k] = in[k];
77-
}
78-
}
44+
return x;
7945
}
8046
else {
81-
using dpctl::tensor::type_utils::is_complex;
82-
if constexpr (is_complex<argT>::value) {
83-
std::uint8_t sgSize = ndit.get_sub_group().get_local_range()[0];
84-
size_t base = ndit.get_global_linear_id();
85-
86-
base = (base / sgSize) * sgSize * n_vecs * vec_sz +
87-
(base % sgSize);
88-
for (size_t offset = base;
89-
offset <
90-
std::min(nelems_, base + sgSize * (n_vecs * vec_sz));
91-
offset += sgSize)
92-
{
93-
out[offset] = std::abs(in[offset]);
94-
}
95-
}
96-
else {
97-
auto sg = ndit.get_sub_group();
98-
std::uint8_t sgSize = sg.get_local_range()[0];
99-
std::uint8_t maxsgSize = sg.get_max_local_range()[0];
100-
size_t base = n_vecs * vec_sz *
101-
(ndit.get_group(0) * ndit.get_local_range(0) +
102-
sg.get_group_id()[0] * maxsgSize);
103-
104-
if (base + n_vecs * vec_sz < nelems_) {
105-
using in_ptrT = sycl::multi_ptr<
106-
const argT, sycl::access::address_space::global_space>;
107-
using out_ptrT = sycl::multi_ptr<
108-
resT, sycl::access::address_space::global_space>;
109-
sycl::vec<argT, vec_sz> arg_vec;
110-
111-
#pragma unroll
112-
for (std::uint8_t it = 0; it < n_vecs * vec_sz;
113-
it += vec_sz) {
114-
arg_vec =
115-
sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
116-
#pragma unroll
117-
for (std::uint8_t k = 0; k < vec_sz; ++k) {
118-
arg_vec[k] = std::abs(arg_vec[k]);
119-
}
120-
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
121-
arg_vec);
122-
}
123-
}
124-
else {
125-
for (size_t k = base + sg.get_local_id()[0]; k < nelems_;
126-
k += sgSize) {
127-
out[k] = std::abs(in[k]);
128-
}
129-
}
130-
}
47+
return std::abs(x);
13148
}
13249
}
13350
};
13451

52+
template <typename argT,
53+
typename resT = argT,
54+
unsigned int vec_sz = 4,
55+
unsigned int n_vecs = 2>
56+
using AbsContigFunctor = elementwise_common::
57+
UnaryContigFunctor<argT, resT, AbsFunctor<argT, resT>, vec_sz, n_vecs>;
58+
13559
template <typename T> struct AbsOutputType
13660
{
13761
using value_type = typename std::disjunction< // disjunction is C++17
@@ -220,39 +144,9 @@ template <typename fnT, typename T> struct AbsTypeMapFactory
220144
}
221145
};
222146

223-
template <typename argT, typename resT, typename IndexerT>
224-
struct AbsStridedFunctor
225-
{
226-
private:
227-
const argT *in = nullptr;
228-
resT *out = nullptr;
229-
IndexerT inp_res_indexer_;
230-
231-
public:
232-
AbsStridedFunctor(const argT *inp_p,
233-
resT *res_p,
234-
IndexerT two_offsets_indexer)
235-
: in(inp_p), out(res_p), inp_res_indexer_(two_offsets_indexer)
236-
{
237-
}
238-
239-
void operator()(sycl::id<1> wid) const
240-
{
241-
auto offsets_ = inp_res_indexer_(static_cast<py::ssize_t>(wid[0]));
242-
const auto &inp_offset = offsets_.get_first_offset();
243-
const auto &out_offset = offsets_.get_second_offset();
244-
245-
if constexpr (std::is_same_v<argT, bool> ||
246-
(std::is_integral<argT>::value &&
247-
std::is_unsigned<argT>::value))
248-
{
249-
out[out_offset] = in[inp_offset];
250-
}
251-
else {
252-
out[out_offset] = std::abs(in[inp_offset]);
253-
}
254-
}
255-
};
147+
template <typename argTy, typename resTy, typename IndexerT>
148+
using AbsStridedFunctor = elementwise_common::
149+
UnaryStridedFunctor<argTy, resTy, IndexerT, AbsFunctor<argTy, resTy>>;
256150

257151
template <typename T1, typename T2, typename T3> class abs_strided_kernel;
258152

0 commit comments

Comments
 (0)