Skip to content

Commit 4c9a340

Browse files
ndgrigorianvtavana
authored andcommitted
Rewrites proj to use a custom implementation
- Avoids bug on CPU that results in incorrect result
1 parent 6b28198 commit 4c9a340

File tree

1 file changed

+12
-6
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+12
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/proj.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
//===---------------------------------------------------------------------===//
2121
///
2222
/// \file
23-
/// This file defines kernels for elementwise evaluation of CONJ(x) function.
23+
/// This file defines kernels for elementwise evaluation of PROJ(x) function.
2424
//===---------------------------------------------------------------------===//
2525

2626
#pragma once
@@ -29,6 +29,7 @@
2929
#include <complex>
3030
#include <cstddef>
3131
#include <cstdint>
32+
#include <limits>
3233
#include <type_traits>
3334

3435
#include "kernels/elementwise_functions/common.hpp"
@@ -62,12 +63,19 @@ template <typename argT, typename resT> struct ProjFunctor
6263
// is function defined for sycl::vec
6364
using supports_vec = typename std::false_type;
6465
// do both argTy and resTy support sugroup store/load operation
65-
using supports_sg_loadstore = typename std::negation<
66-
std::disjunction<is_complex<resT>, is_complex<argT>>>;
66+
using supports_sg_loadstore = typename std::false_type;
6767

6868
resT operator()(const argT &in)
6969
{
70-
return std::proj(in);
70+
using realT = typename argT::value_type;
71+
const realT x = std::real(in);
72+
const realT y = std::imag(in);
73+
74+
if (std::isinf(x) || std::isinf(y)) {
75+
const realT res_im = std::copysign(0.0, y);
76+
return resT{std::numeric_limits<realT>::infinity(), res_im};
77+
}
78+
return in;
7179
}
7280
};
7381

@@ -86,8 +94,6 @@ template <typename T> struct ProjOutputType
8694
{
8795
using value_type = typename std::disjunction< // disjunction is C++17
8896
// feature, supported by DPC++
89-
td_ns::TypeMapResultEntry<T, float, std::complex<float>>,
90-
td_ns::TypeMapResultEntry<T, double, std::complex<double>>,
9197
td_ns::TypeMapResultEntry<T, std::complex<float>>,
9298
td_ns::TypeMapResultEntry<T, std::complex<double>>,
9399
td_ns::DefaultResultEntry<void>>::result_type;

0 commit comments

Comments
 (0)