Skip to content

Commit 545de34

Browse files
ndgrigorianvtavana
authored andcommitted
Rewrites proj to use a custom implementation
- Avoids bug on CPU that results in incorrect result
1 parent 8004537 commit 545de34

File tree

3 files changed

+12
-36
lines changed

3 files changed

+12
-36
lines changed

dpctl/tensor/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,12 @@
100100
equal,
101101
exp,
102102
expm1,
103-
imag,
104103
isfinite,
105104
isinf,
106105
isnan,
107106
log,
108107
log1p,
109108
multiply,
110-
proj,
111-
real,
112109
sin,
113110
sqrt,
114111
)
@@ -194,14 +191,11 @@
194191
"cos",
195192
"exp",
196193
"expm1",
197-
"imag",
198194
"isinf",
199195
"isnan",
200196
"isfinite",
201197
"log",
202198
"log1p",
203-
"proj",
204-
"real",
205199
"sin",
206200
"sqrt",
207201
"divide",

dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
//===---------------------------------------------------------------------===//
2121
///
2222
/// \file
23-
/// This file defines kernels for elementwise evaluation of CONJ(x) function.
23+
/// This file defines kernels for elementwise evaluation of PROJ(x) function.
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
@@ -29,6 +29,7 @@
2929
#include <complex>
3030
#include <cstddef>
3131
#include <cstdint>
32+
#include <limits>
3233
#include <type_traits>
3334

3435
#include "kernels/elementwise_functions/common.hpp"
@@ -62,12 +63,19 @@ template <typename argT, typename resT> struct ProjFunctor
6263
// is function defined for sycl::vec
6364
using supports_vec = typename std::false_type;
6465
// do both argTy and resTy support sugroup store/load operation
65-
using supports_sg_loadstore = typename std::negation<
66-
std::disjunction<is_complex<resT>, is_complex<argT>>>;
66+
using supports_sg_loadstore = typename std::false_type;
6767

6868
resT operator()(const argT &in)
6969
{
70-
return std::proj(in);
70+
using realT = typename argT::value_type;
71+
const realT x = std::real(in);
72+
const realT y = std::imag(in);
73+
74+
if (std::isinf(x) || std::isinf(y)) {
75+
const realT res_im = std::copysign(0.0, y);
76+
return resT{std::numeric_limits<realT>::infinity(), res_im};
77+
}
78+
return in;
7179
}
7280
};
7381

@@ -86,8 +94,6 @@ template <typename T> struct ProjOutputType
8694
{
8795
using value_type = typename std::disjunction< // disjunction is C++17
8896
// feature, supported by DPC++
89-
td_ns::TypeMapResultEntry<T, float, std::complex<float>>,
90-
td_ns::TypeMapResultEntry<T, double, std::complex<double>>,
9197
td_ns::TypeMapResultEntry<T, std::complex<float>>,
9298
td_ns::TypeMapResultEntry<T, std::complex<double>>,
9399
td_ns::DefaultResultEntry<void>>::result_type;

dpctl/tensor/libtensor/source/elementwise_functions.cpp

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,39 +37,15 @@
3737
#include "kernels/elementwise_functions/conj.hpp"
3838
#include "kernels/elementwise_functions/cos.hpp"
3939
#include "kernels/elementwise_functions/equal.hpp"
40-
#include "kernels/elementwise_functions/isfinite.hpp"
41-
#include "kernels/elementwise_functions/isinf.hpp"
42-
#include "kernels/elementwise_functions/isnan.hpp"
43-
<<<<<<< HEAD
44-
=======
45-
<<<<<<< HEAD
46-
#include "kernels/elementwise_functions/log.hpp"
47-
#include "kernels/elementwise_functions/log1p.hpp"
48-
#include "kernels/elementwise_functions/multiply.hpp"
49-
#include "kernels/elementwise_functions/sin.hpp"
50-
=======
51-
<<<<<<< HEAD
52-
=======
53-
=======
54-
#include "kernels/elementwise_functions/expm1.hpp"
55-
=======
56-
>>>>>>> 91c53d433... address reviewer comments
5740
#include "kernels/elementwise_functions/exp.hpp"
5841
#include "kernels/elementwise_functions/expm1.hpp"
59-
#include "kernels/elementwise_functions/imag.hpp"
6042
#include "kernels/elementwise_functions/isfinite.hpp"
6143
#include "kernels/elementwise_functions/isinf.hpp"
6244
#include "kernels/elementwise_functions/isnan.hpp"
6345
#include "kernels/elementwise_functions/log.hpp"
6446
#include "kernels/elementwise_functions/log1p.hpp"
6547
#include "kernels/elementwise_functions/multiply.hpp"
66-
#include "kernels/elementwise_functions/proj.hpp"
67-
#include "kernels/elementwise_functions/real.hpp"
6848
#include "kernels/elementwise_functions/sin.hpp"
69-
>>>>>>> 94523d1ea... impl elementwise exp and sin
70-
>>>>>>> 653e3b5a1... impl_real_imag_conj
71-
>>>>>>> 74875eced... impl_real_imag_conj
72-
>>>>>>> f12aeb6f1... impl_real_imag_conj
7349
#include "kernels/elementwise_functions/sqrt.hpp"
7450
#include "kernels/elementwise_functions/true_divide.hpp"
7551

0 commit comments

Comments
 (0)