Complex support for expm1 (pytorch#96644)
Fixes pytorch#92619

Pull Request resolved: pytorch#96644
Approved by: https://github.com/soulitzer
yhl48 authored and pytorchmergebot committed Mar 24, 2023
1 parent 1b8b82f commit 6fcd671
Showing 14 changed files with 75 additions and 20 deletions.
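
For context (not part of the changed files), a minimal usage sketch of what this commit enables, assuming a PyTorch build that includes it: torch.expm1 accepts complex inputs and agrees with exp(z) - 1 away from zero.

import torch

# Sketch only: requires a build containing this commit; the values are illustrative.
z = torch.tensor([0.1 + 1.2j, -0.5 + 0.25j], dtype=torch.complex128)
out = torch.expm1(z)        # previously rejected complex dtypes with a "not implemented" error
ref = torch.exp(z) - 1      # mathematically equal, but loses precision near z == 0
print(torch.allclose(out, ref))   # True
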
2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -241,7 +241,7 @@ template <> class Vectorized<c10::complex<double>> {
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
- AT_ERROR("not supported for complex numbers");
+ return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);

2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -275,7 +275,7 @@ template <> class Vectorized<c10::complex<float>> {
return scaled_values.exp();
}
Vectorized<c10::complex<float>> expm1() const {
- AT_ERROR("not supported for complex numbers");
+ return map(std::expm1);
}
Vectorized<c10::complex<float>> sin() const {
return map(std::sin);

7 changes: 3 additions & 4 deletions aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h
@@ -456,6 +456,9 @@ class Vectorized<ComplexDbl> {
Vectorized<ComplexDbl> exp2() const {
return map(exp2_impl);
}
+ Vectorized<ComplexDbl> expm1() const {
+ return map(std::expm1);
+ }

Vectorized<ComplexDbl> pow(const Vectorized<ComplexDbl>& exp) const {
__at_align__ ComplexDbl x_tmp[size()];
@@ -498,10 +501,6 @@ class Vectorized<ComplexDbl> {
TORCH_CHECK(false, "not supported for complex numbers");
}

- Vectorized<ComplexDbl> expm1() const {
- TORCH_CHECK(false, "not supported for complex numbers");
- }

Vectorized<ComplexDbl> operator<(const Vectorized<ComplexDbl>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}

7 changes: 3 additions & 4 deletions aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h
@@ -535,6 +535,9 @@ class Vectorized<ComplexFlt> {
Vectorized<ComplexFlt> exp2() const {
return map(exp2_impl);
}
+ Vectorized<ComplexFlt> expm1() const {
+ return map(std::expm1);
+ }

Vectorized<ComplexFlt> eq(const Vectorized<ComplexFlt>& other) const {
auto ret = (*this == other);
@@ -575,10 +578,6 @@ class Vectorized<ComplexFlt> {
TORCH_CHECK(false,"not supported for complex numbers");
}

- Vectorized<ComplexFlt> expm1() const {
- TORCH_CHECK(false,"not supported for complex numbers");
- }

Vectorized<ComplexFlt> operator<(const Vectorized<ComplexFlt>& other) const {
TORCH_CHECK(false, "not supported for complex numbers");
}

2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -304,7 +304,7 @@ template <> class Vectorized<c10::complex<double>> {
return scaled_values.exp();
}
Vectorized<c10::complex<double>> expm1() const {
- AT_ERROR("not supported for complex numbers");
+ return map(std::expm1);
}
Vectorized<c10::complex<double>> sin() const {
return map(std::sin);

2 changes: 1 addition & 1 deletion aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -806,7 +806,7 @@ template <> class Vectorized<c10::complex<float>> {
return scaled_values.exp();
}
Vectorized<c10::complex<float>> expm1() const {
- AT_ERROR("not supported for complex numbers");
+ return map(std::expm1);
}
Vectorized<c10::complex<float>> sin() const {
return map(std::sin);

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -774,7 +774,7 @@ IMPLEMENT_FLOAT_KERNEL(erfinv)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_COMPLEX_KERNEL(exp)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
- IMPLEMENT_FLOAT_KERNEL(expm1)
+ IMPLEMENT_COMPLEX_KERNEL(expm1)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
IMPLEMENT_FLOAT_KERNEL(floor)
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/ForeachUnaryOp.cu
@@ -191,7 +191,6 @@ STD_FUNCTOR(op_name, functor_name); \
OP_CUSTOM_FUNCTOR(function, op_name, functor_name); \

OP(floating_half_bfloat16, erfc, Erfc);
- OP(floating_half_bfloat16, expm1, Expm1);
OP(floating_half, lgamma, Lgamma);
OP(floating_half_bfloat16, trunc, Truncf);
OP(floating_half_bfloat16, floor, Floor);
@@ -206,6 +205,7 @@ OP(floating_complex_half_bfloat16, sin, Sin);
OP(floating_complex_half_bfloat16, sinh, Sinh);

OP(floating_complex_half_bfloat16, exp, Exp);
+ OP(floating_complex_half_bfloat16, expm1, Expm1);
OP(floating_complex_half_bfloat16, tanh, Tanh);
OP(floating_complex_half_bfloat16, log, Log);
OP(floating_complex_half_bfloat16, log10, Log10);
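
Since the CUDA foreach kernel above now registers expm1 in the floating+complex dtype group, the fused list variant should accept complex tensors as well. A hedged sketch using the private binding torch._foreach_expm1 (CPU tensors shown for simplicity; the CUDA path is what this file implements):

import torch

# Sketch only: torch._foreach_expm1 is a private API; complex support assumed per this diff.
xs = [torch.randn(4, dtype=torch.complex64) for _ in range(3)]
ys = torch._foreach_expm1(xs)            # one call over the whole list of tensors
ref = [torch.expm1(x) for x in xs]       # per-tensor equivalent
print(all(torch.allclose(y, r) for y, r in zip(ys, ref)))   # True
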

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -69,7 +69,7 @@ void exp_kernel_cuda(TensorIteratorBase& iter) {
}

void expm1_kernel_cuda(TensorIteratorBase& iter) {
- AT_DISPATCH_FLOATING_TYPES_AND2(
+ AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
ScalarType::BFloat16, ScalarType::Half,
iter.common_dtype(), "expm1_cuda",
[&]() {

35 changes: 35 additions & 0 deletions c10/test/util/complex_math_test_common.h
@@ -74,6 +74,41 @@ C10_DEFINE_TEST(TestExponential, EulerFormula) {
}
}

C10_DEFINE_TEST(TestExpm1, Normal) {
// expm1(x) = exp(x) - 1
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l1 = std::expm1(x);
c10::complex<float> l2 = std::exp(x) - 1.0f;
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l1 = std::expm1(x);
c10::complex<double> l2 = std::exp(x) - 1.0;
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
}

C10_DEFINE_TEST(TestExpm1, Small) {
// expm1(x) = exp(x) - 1
// expm1(x) provides greater precision than exp(x) - 1 for small values of x
{
c10::complex<float> x(1e-30, 1e-30);
c10::complex<float> l1 = std::expm1(x);
C10_ASSERT_NEAR(l1.real(), 1e-30, tol);
C10_ASSERT_NEAR(l1.imag(), 1e-30, tol);
}
{
c10::complex<double> x(1e-100, 1e-100);
c10::complex<double> l1 = std::expm1(x);
C10_ASSERT_NEAR(l1.real(), 1e-100, tol);
C10_ASSERT_NEAR(l1.imag(), 1e-100, tol);
}
}

C10_DEFINE_TEST(TestLog, Definition) {
// log(x) = log(r) + i*theta
{
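
The TestExpm1 cases above encode the usual motivation for expm1: for |z| near zero, exp(z) - 1 cancels catastrophically in the real part, while expm1 keeps the leading-order term. A small illustration (assumes a build with this commit):

import torch

z = torch.tensor(1e-30 + 1e-30j, dtype=torch.complex128)
naive = torch.exp(z) - 1      # real part rounds to 0: exp(z).real is exactly 1.0 in double precision
better = torch.expm1(z)       # real part stays ~1e-30, which is what the test asserts
print(naive.real.item(), better.real.item())   # 0.0  ~1e-30
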

19 changes: 19 additions & 0 deletions c10/util/complex_math.h
@@ -318,6 +318,23 @@ C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
}
}

template <typename T>
C10_HOST_DEVICE inline c10::complex<T> expm1(const c10::complex<T>& z) {
// expm1(z) = exp(z) - 1
// Define z = x + i * y
// f = e ^ (x + i * y) - 1
// = e ^ x * e ^ (i * y) - 1
// = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y))
// = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y)
// = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y)
T x = z.real();
T y = z.imag();
T a = std::sin(y / 2);
T er = std::expm1(x) * std::cos(y) - T(2) * a * a;
T ei = std::exp(x) * std::sin(y);
return {er, ei};
}

} // namespace c10_complex_math

using c10_complex_math::acos;
@@ -329,6 +346,7 @@ using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
@@ -351,6 +369,7 @@ using c10_complex_math::atanh;
using c10_complex_math::cos;
using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::expm1;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
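
For readers following the derivation in the new expm1 above, the same real/imaginary split can be transcribed into plain Python and checked against the naive exp(z) - 1 (a sketch of the identity, not the shipped C++ code):

import cmath
import math

def complex_expm1(z: complex) -> complex:
    # expm1(x + i*y) = expm1(x)*cos(y) - 2*sin(y/2)**2 + i*exp(x)*sin(y)
    x, y = z.real, z.imag
    s = math.sin(y / 2)
    return complex(math.expm1(x) * math.cos(y) - 2.0 * s * s,
                   math.exp(x) * math.sin(y))

z = 0.1 + 1.2j
print(complex_expm1(z))     # ~(-0.5995+1.0301j)
print(cmath.exp(z) - 1)     # same value via the direct formula
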

2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -621,7 +621,7 @@
result: auto_element_wise

- name: expm1(Tensor self) -> Tensor
- self: grad * (result + 1)
+ self: grad * (result.conj() + 1)
result: auto_element_wise

# TODO: this derivative is not SymInt safe, need sum_to support
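
The derivatives.yaml change conjugates the saved result because, for a holomorphic op, PyTorch's complex backward multiplies the incoming gradient by the conjugate of the derivative; here d/dz expm1(z) = exp(z) = result + 1. A hedged numerical check of that rule:

import torch

z = torch.tensor([0.3 - 0.7j, -1.1 + 0.4j], dtype=torch.complex128, requires_grad=True)

# Finite-difference check of the backward (covers the complex/Wirtinger convention).
print(torch.autograd.gradcheck(torch.expm1, (z,)))      # True

# The yaml rule by hand: vjp = grad_out * (expm1(z).conj() + 1).
grad_out = torch.tensor([1.0 + 2.0j, -0.5 + 0.0j], dtype=torch.complex128)
g, = torch.autograd.grad(torch.expm1(z), z, grad_outputs=grad_out)
expected = grad_out * (torch.expm1(z.detach()).conj() + 1)
print(torch.allclose(g, expected))                      # True
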

1 change: 1 addition & 0 deletions tools/autograd/gen_variable_type.py
@@ -264,6 +264,7 @@
"fill_",
"exp",
"exp2",
"expm1",
"nonzero",
"mean",
"std_mean",

10 changes: 6 additions & 4 deletions torch/testing/_internal/common_methods_invocations.py
@@ -8324,8 +8324,8 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs):

ForeachFuncInfo(
'expm1',
- dtypes=floating_types_and(torch.bfloat16),
- dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
+ dtypes=floating_and_complex_types_and(torch.bfloat16),
+ dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
supports_autograd=True,
),
@@ -14456,8 +14456,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
UnaryUfuncInfo('expm1',
aliases=('special.expm1', ),
ref=np_unary_ufunc_integer_promotion_wrapper(np.expm1),
- dtypes=all_types_and(torch.bool, torch.bfloat16),
- dtypesIfCUDA=all_types_and(torch.bool, torch.half, torch.bfloat16),
+ dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16),
+ dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_sparse=True,
@@ -14472,6 +14472,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
device_type='cuda', dtypes=[torch.complex128]),
DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small',
device_type='cpu', dtypes=[torch.bfloat16]),
DecorateInfo(unittest.skip("Skipped! sparse backward not supported"),