Skip to content

Commit ed93e02

Browse files
authored
Fix expm1 for complex special cases (#1332)
* Fixed some complex special cases for expm1 * Silence remaining warnings in elementwise tests * Tweaks to expm1 complex special case logic * expm1 special case change For `inf` real part and finite, nonzero imaginary part, now guaranteed to be (+/-`inf`, +/-`inf`), with cosine and sine of the imaginary part determining the sign, respectively
1 parent 597cd4b commit ed93e02

File tree

3 files changed

+115
-29
lines changed

3 files changed

+115
-29
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <cmath>
2929
#include <cstddef>
3030
#include <cstdint>
31+
#include <limits>
3132
#include <type_traits>
3233

3334
#include "kernels/elementwise_functions/common.hpp"
@@ -73,6 +74,46 @@ template <typename argT, typename resT> struct Expm1Functor
7374
const realT x = std::real(in);
7475
const realT y = std::imag(in);
7576

77+
// special cases
78+
if (std::isinf(x)) {
79+
if (x > realT(0)) {
80+
// positive infinity cases
81+
if (!std::isfinite(y)) {
82+
return resT{x, std::numeric_limits<realT>::quiet_NaN()};
83+
}
84+
else if (y == realT(0)) {
85+
return in;
86+
}
87+
else {
88+
return (resT{std::copysign(x, std::cos(y)),
89+
std::copysign(x, std::sin(y))});
90+
}
91+
}
92+
else {
93+
// negative infinity cases
94+
if (!std::isfinite(y)) {
95+
// copy sign of y to guarantee
96+
// conj(expm1(x)) == expm1(conj(x))
97+
return resT{realT(-1), std::copysign(realT(0), y)};
98+
}
99+
else {
100+
return resT{realT(-1),
101+
std::copysign(realT(0), std::sin(y))};
102+
}
103+
}
104+
}
105+
106+
if (std::isnan(x)) {
107+
if (y == realT(0)) {
108+
return in;
109+
}
110+
else {
111+
return resT{std::numeric_limits<realT>::quiet_NaN(),
112+
std::numeric_limits<realT>::quiet_NaN()};
113+
}
114+
}
115+
116+
// x, y finite numbers
76117
realT cosY_val;
77118
const realT sinY_val = sycl::sincos(y, &cosY_val);
78119
const realT sinhalfY_val = std::sin(y / 2);

dpctl/tests/elementwise/test_expm1.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -116,29 +116,53 @@ def test_expm1_order(dtype):
116116

117117

118118
def test_expm1_special_cases():
119-
q = get_queue_or_skip()
119+
get_queue_or_skip()
120120

121-
X = dpt.asarray(
122-
[dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q
123-
)
124-
Xnp = dpt.asnumpy(X)
121+
X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4")
122+
res = np.asarray([np.nan, 0.0, -0.0, np.inf, -1.0], dtype="f4")
125123

126124
tol = dpt.finfo(X.dtype).resolution
127-
assert_allclose(
128-
dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol
129-
)
125+
assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol)
130126

131127
# special cases for complex variant
128+
num_finite = 1.0
132129
vals = [
133-
complex(*val)
134-
for val in itertools.permutations(
135-
[dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0], 2
136-
)
130+
complex(0.0, 0.0),
131+
complex(num_finite, dpt.inf),
132+
complex(num_finite, dpt.nan),
133+
complex(dpt.inf, 0.0),
134+
complex(-dpt.inf, num_finite),
135+
complex(dpt.inf, num_finite),
136+
complex(-dpt.inf, dpt.inf),
137+
complex(dpt.inf, dpt.inf),
138+
complex(-dpt.inf, dpt.nan),
139+
complex(dpt.inf, dpt.nan),
140+
complex(dpt.nan, 0.0),
141+
complex(dpt.nan, num_finite),
142+
complex(dpt.nan, dpt.nan),
137143
]
138144
X = dpt.asarray(vals, dtype=dpt.complex64)
139-
Xnp = dpt.asnumpy(X)
145+
cis_1 = complex(np.cos(num_finite), np.sin(num_finite))
146+
c_nan = complex(np.nan, np.nan)
147+
res = np.asarray(
148+
[
149+
complex(0.0, 0.0),
150+
c_nan,
151+
c_nan,
152+
complex(np.inf, 0.0),
153+
0.0 * cis_1 - 1.0,
154+
np.inf * cis_1 - 1.0,
155+
complex(-1.0, 0.0),
156+
complex(np.inf, np.nan),
157+
complex(-1.0, 0.0),
158+
complex(np.inf, np.nan),
159+
complex(np.nan, 0.0),
160+
c_nan,
161+
c_nan,
162+
],
163+
dtype=np.complex64,
164+
)
140165

141166
tol = dpt.finfo(X.dtype).resolution
142-
assert_allclose(
143-
dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol
144-
)
167+
with np.errstate(invalid="ignore"):
168+
assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol)

dpctl/tests/elementwise/test_log1p.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -119,28 +119,49 @@ def test_log1p_special_cases():
119119
q = get_queue_or_skip()
120120

121121
X = dpt.asarray(
122-
[dpt.nan, -1.0, -2.0, 0.0, -0.0, dpt.inf, -dpt.inf],
122+
[dpt.nan, -2.0, -1.0, -0.0, 0.0, dpt.inf],
123123
dtype="f4",
124124
sycl_queue=q,
125125
)
126-
Xnp = dpt.asnumpy(X)
126+
res = np.asarray([np.nan, np.nan, -np.inf, -0.0, 0.0, np.inf])
127127

128128
tol = dpt.finfo(X.dtype).resolution
129-
assert_allclose(
130-
dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol
131-
)
129+
with np.errstate(divide="ignore", invalid="ignore"):
130+
assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol)
132131

133132
# special cases for complex
134133
vals = [
135-
complex(*val)
136-
for val in itertools.permutations(
137-
[dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0, -1.0, -2.0], 2
138-
)
134+
complex(-1.0, 0.0),
135+
complex(2.0, dpt.inf),
136+
complex(2.0, dpt.nan),
137+
complex(-dpt.inf, 1.0),
138+
complex(dpt.inf, 1.0),
139+
complex(-dpt.inf, dpt.inf),
140+
complex(dpt.inf, dpt.inf),
141+
complex(dpt.inf, dpt.nan),
142+
complex(dpt.nan, 1.0),
143+
complex(dpt.nan, dpt.inf),
144+
complex(dpt.nan, dpt.nan),
139145
]
140146
X = dpt.asarray(vals, dtype=dpt.complex64)
141-
Xnp = dpt.asnumpy(X)
147+
c_nan = complex(np.nan, np.nan)
148+
res = np.asarray(
149+
[
150+
complex(-np.inf, 0.0),
151+
complex(np.inf, np.pi / 2),
152+
c_nan,
153+
complex(np.inf, np.pi),
154+
complex(np.inf, 0.0),
155+
complex(np.inf, 3 * np.pi / 4),
156+
complex(np.inf, np.pi / 4),
157+
complex(np.inf, np.nan),
158+
c_nan,
159+
complex(np.inf, np.nan),
160+
c_nan,
161+
],
162+
dtype=np.complex64,
163+
)
142164

143165
tol = dpt.finfo(X.dtype).resolution
144-
assert_allclose(
145-
dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol
146-
)
166+
with np.errstate(invalid="ignore"):
167+
assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol)

0 commit comments

Comments
 (0)