Skip to content

Commit c73a526

Browse files
swolchok authored and pytorchmergebot committed
Extract reusable portions of elu_kernel into header (#149673)
Similar to #140425, we are making the implementation usable via header-only code sharing. Review note: #62546 by @yanbing-j removed expm1 usage from this path. I don't know why and expm1 should be more efficient, so I've put it back. Please let me know if there is a good reason I shouldn't. Testing: existing correctness tests should cover. Pull Request resolved: #149673 Approved by: https://github.com/cyyever, https://github.com/Skylion007
1 parent b238e36 commit c73a526

File tree

4 files changed

+84
-53
lines changed

4 files changed

+84
-53
lines changed

aten/src/ATen/native/cpu/Activation.cpp

+5-43
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <ATen/cpu/vec/functional.h>
1616
#include <ATen/cpu/vec/vec.h>
1717
#include <ATen/native/TensorIterator.h>
18+
#include <ATen/native/cpu/Elu.h>
1819
#include <ATen/native/cpu/Gelu.h>
1920
#include <ATen/native/cpu/Loops.h>
2021
#include <ATen/Parallel.h>
@@ -190,56 +191,17 @@ static void threshold_kernel(
190191
void elu_kernel(TensorIteratorBase& it, const Scalar& alpha, const Scalar& scale, const Scalar& input_scale) {
191192
if (at::isReducedFloatingType(it.common_dtype())) {
192193
AT_DISPATCH_REDUCED_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
193-
auto negcoef = alpha.to<float>() * scale.to<float>();
194-
auto poscoef = scale.to<float>();
195-
auto negiptcoef = input_scale.to<float>();
196-
const Vectorized<float> negcoef_vec(negcoef);
197-
const Vectorized<float> negiptcoef_vec(negiptcoef);
198-
const Vectorized<float> poscoef_vec(poscoef);
199-
const Vectorized<float> one_vec(static_cast<float>(1));
200-
const Vectorized<float> zero_vec(static_cast<float>(0));
201194
cpu_kernel_vec(
202195
it,
203-
[negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
204-
return float(a) <= float(0) ? (std::exp(float(a) * negiptcoef) - float(1)) * negcoef : float(a) * poscoef;
205-
},
206-
[&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vectorized<scalar_t> a) -> Vectorized<scalar_t> {
207-
auto [a0, a1] = convert_to_float<scalar_t>(a);
208-
auto cmp0 = (a0 > zero_vec);
209-
auto cmp1 = (a1 > zero_vec);
210-
auto get_res_masked = [&](Vectorized<float>& cmp, Vectorized<float>& a) {
211-
return !cmp.zero_mask() ? a * poscoef_vec :
212-
Vectorized<float>::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
213-
};
214-
auto res0 = get_res_masked(cmp0, a0);
215-
auto res1 = get_res_masked(cmp1, a1);
216-
return convert_from_float<scalar_t>(res0, res1);
217-
});
196+
get_scalar_elu_elementwise_func<scalar_t, float>(alpha.to<float>(), scale.to<float>(), input_scale.to<float>()),
197+
get_vectorized_elu_elementwise_func<scalar_t>(alpha.to<float>(), scale.to<float>(), input_scale.to<float>()));
218198
});
219199
} else {
220200
AT_DISPATCH_FLOATING_TYPES(it.common_dtype(), "elu_cpu", [&]() {
221-
using Vec = Vectorized<scalar_t>;
222-
auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
223-
auto poscoef = scale.to<scalar_t>();
224-
auto negiptcoef = input_scale.to<scalar_t>();
225-
const Vec negcoef_vec(negcoef);
226-
const Vec negiptcoef_vec(negiptcoef);
227-
const Vec poscoef_vec(poscoef);
228-
const Vec one_vec(static_cast<scalar_t>(1));
229-
const Vec zero_vec(static_cast<scalar_t>(0));
230201
cpu_kernel_vec(
231202
it,
232-
[negcoef, negiptcoef, poscoef](scalar_t a) -> scalar_t {
233-
return a <= scalar_t(0) ? (std::exp(a * negiptcoef) - scalar_t(1)) * negcoef : a * poscoef;
234-
},
235-
[&negcoef_vec, &negiptcoef_vec, &poscoef_vec, &one_vec, &zero_vec](Vec a) -> Vec {
236-
auto cmp = (a > zero_vec);
237-
if (!cmp.zero_mask()) { // only a * poscoef (which is very quick) needs to be computed
238-
return a * poscoef_vec;
239-
} else {
240-
return Vec::blendv(((a * negiptcoef_vec).exp() - one_vec) * negcoef_vec, a * poscoef_vec, cmp);
241-
}
242-
});
203+
get_scalar_elu_elementwise_func<scalar_t>(alpha.to<scalar_t>(), scale.to<scalar_t>(), input_scale.to<scalar_t>()),
204+
get_vectorized_elu_elementwise_func<scalar_t>(alpha.to<scalar_t>(), scale.to<scalar_t>(), input_scale.to<scalar_t>()));
243205
});
244206
}
245207
}

aten/src/ATen/native/cpu/Elu.h

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once
2+
3+
// On Windows, math.h needs to be included with _USE_MATH_DEFINES defined to
4+
// access constants such as M_SQRT2 and M_2_SQRTPI.
5+
#ifdef _WIN32
6+
#define _USE_MATH_DEFINES
7+
#include <cmath>
8+
#endif // _WIN32
9+
10+
#include <ATen/cpu/vec/vec.h>
11+
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
12+
13+
namespace at::native {
14+
/**
15+
* Return a function object that calculates ELU with the given
16+
* parameters on its input element. ParamT is the type of the input
17+
* and output to the ELU, and MathT is the type (possibly
18+
* higher-precision, e.g. float if ParamT is reduced-precision float)
19+
* in which to do intermediate calculations.
20+
*/
21+
template <typename ParamT, typename MathT=ParamT>
22+
auto get_scalar_elu_elementwise_func(MathT alpha, MathT scale, MathT input_scale) {
23+
const auto negcoef = alpha * scale;
24+
const auto poscoef = scale;
25+
const auto negiptcoef = input_scale;
26+
return [negcoef, negiptcoef, poscoef](ParamT a) -> ParamT {
27+
return MathT(a) <= MathT(0)
28+
? std::expm1(MathT(a) * negiptcoef) * negcoef
29+
: MathT(a) * poscoef;
30+
};
31+
}
32+
33+
/**
34+
* Return a function object that calculates ELU with the given
35+
* parameters on its input element. The function object takes and
36+
* returns Vectorized<T>.
37+
*/
38+
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
39+
auto get_vectorized_elu_elementwise_func(T alpha, T scale, T input_scale) {
40+
const vec::Vectorized<T> negcoef_vec(alpha * scale);
41+
const vec::Vectorized<T> poscoef_vec(scale);
42+
const vec::Vectorized<T> negiptcoef_vec(input_scale);
43+
const vec::Vectorized<T> zero_vec(static_cast<T>(0));
44+
return [negcoef_vec, poscoef_vec, negiptcoef_vec, zero_vec](vec::Vectorized<T> a) -> vec::Vectorized<T> {
45+
const auto cmp = a > zero_vec;
46+
if (!cmp.zero_mask()) {
47+
return a * poscoef_vec;
48+
} else {
49+
return vec::Vectorized<T>::blendv((a * negiptcoef_vec).expm1() * negcoef_vec, a * poscoef_vec, cmp);
50+
}
51+
};
52+
}
53+
54+
/**
55+
* Return a function object that calculates ELU with the given
56+
* parameters on its input element. The function object takes and
57+
* returns Vectorized<ParamT>, and Vectorized<MathT> is the type
58+
* (possibly higher-precision) in which to do intermediate
59+
* calculations.
60+
*/
61+
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
62+
auto get_vectorized_elu_elementwise_func(float alpha, float scale, float input_scale) {
63+
// Takes float->float.
64+
const auto float_func = get_vectorized_elu_elementwise_func<float>(alpha, scale, input_scale);
65+
return [float_func](vec::Vectorized<T> a) -> vec::Vectorized<T> {
66+
auto [a0, a1] = vec::convert_to_float<T>(a);
67+
auto res0 = float_func(a0);
68+
auto res1 = float_func(a1);
69+
return vec::convert_from_float<T>(res0, res1);
70+
};
71+
}
72+
} // namespace at::native

test/cpp/api/functional.cpp

+4-6
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ TEST_F(FunctionalTest, ELU) {
10631063
x_bf16.resize_({size, size, size});
10641064

10651065
auto y_exp = torch::max(torch::zeros_like(x), x) +
1066-
torch::min(torch::zeros_like(x), alpha * (torch::exp(x) - 1.0));
1066+
torch::min(torch::zeros_like(x), alpha * (torch::expm1(x)));
10671067
auto y = F::elu(x, F::ELUFuncOptions().alpha(alpha).inplace(inplace));
10681068
auto y_bf16 =
10691069
F::elu(x_bf16, F::ELUFuncOptions().alpha(alpha).inplace(inplace));
@@ -1090,8 +1090,7 @@ TEST_F(FunctionalTest, SELU) {
10901090
auto input_bf16 = input.clone().to(torch::kBFloat16);
10911091
auto expected = scale *
10921092
(torch::max(torch::zeros_like(input), input) +
1093-
torch::min(
1094-
torch::zeros_like(input), alpha * (torch::exp(input) - 1)));
1093+
torch::min(torch::zeros_like(input), alpha * (torch::expm1(input))));
10951094
auto output = F::selu(input, inplace);
10961095
auto output_bf16 = F::selu(input_bf16, inplace);
10971096

@@ -1711,8 +1710,7 @@ TEST_F(FunctionalTest, CELU) {
17111710
x.resize_({size, size, size});
17121711
auto x_bf16 = x.clone().to(torch::kBFloat16);
17131712
auto y_exp = torch::max(torch::zeros_like(x), x) +
1714-
torch::min(torch::zeros_like(x),
1715-
alpha * (torch::exp(x / alpha) - 1.0));
1713+
torch::min(torch::zeros_like(x), alpha * (torch::expm1(x / alpha)));
17161714
auto y = F::celu(x, F::CELUFuncOptions().alpha(alpha).inplace(inplace));
17171715
auto y_bf16 =
17181716
F::celu(x_bf16, F::CELUFuncOptions().alpha(alpha).inplace(inplace));
@@ -1737,7 +1735,7 @@ TEST_F(FunctionalTest, CELUDefaultOptions) {
17371735
x.resize_({size, size, size});
17381736
auto x_bf16 = x.clone().to(torch::kBFloat16);
17391737
auto y_exp = torch::max(torch::zeros_like(x), x) +
1740-
torch::min(torch::zeros_like(x), alpha * (torch::exp(x / alpha) - 1.0));
1738+
torch::min(torch::zeros_like(x), alpha * (torch::expm1(x / alpha)));
17411739
auto y = F::celu(x);
17421740
auto y_bf16 = F::celu(x_bf16);
17431741

test/cpp/api/modules.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -2432,8 +2432,7 @@ TEST_F(ModulesTest, ELU) {
24322432
ASSERT_EQ(y.ndimension(), 3);
24332433
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
24342434
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
2435-
torch::min(torch::zeros_like(x_orig),
2436-
alpha * (torch::exp(x_orig) - 1.0));
2435+
torch::min(torch::zeros_like(x_orig), alpha * (torch::expm1(x_orig)));
24372436
ASSERT_TRUE(torch::allclose(y, y_exp));
24382437
if (inplace) {
24392438
ASSERT_TRUE(torch::allclose(x, y_exp));
@@ -2458,7 +2457,7 @@ TEST_F(ModulesTest, SELU) {
24582457
auto zero = torch::zeros_like(input);
24592458
auto expected = scale *
24602459
(torch::max(zero, input_orig) +
2461-
torch::min(zero, alpha * (torch::exp(input_orig) - 1)));
2460+
torch::min(zero, alpha * (torch::expm1(input_orig))));
24622461
auto s = output.sum();
24632462

24642463
ASSERT_EQ(s.ndimension(), 0);
@@ -2848,7 +2847,7 @@ TEST_F(ModulesTest, CELU) {
28482847
ASSERT_EQ(y.sizes(), std::vector<int64_t>({size, size, size}));
28492848
auto y_exp = torch::max(torch::zeros_like(x_orig), x_orig) +
28502849
torch::min(torch::zeros_like(x_orig),
2851-
alpha * (torch::exp(x_orig / alpha) - 1.0));
2850+
alpha * (torch::expm1(x_orig / alpha)));
28522851
ASSERT_TRUE(torch::allclose(y, y_exp));
28532852
if (inplace) {
28542853
ASSERT_TRUE(torch::allclose(x, y_exp));

0 commit comments

Comments (0)