
Commit 038345c

Add optimized ELU implementation
This uses PyTorch code sharing, so we'll need a pin bump to pick up pytorch/pytorch#149673 (and pytorch/pytorch#149684 and, when it lands, pytorch/pytorch#149780).

ghstack-source-id: 50f2d19171aa73280a6d68e297202d9ea421df8d
ghstack-comment-id: 2744722101
Pull Request resolved: #9521
1 parent 1dcabde commit 038345c
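For context on the shared code this commit pulls in: the ATen helpers used below (get_scalar_elu_elementwise_func / get_vectorized_elu_elementwise_func) compute ELU parameterized by alpha, scale, and input_scale. The following scalar reference is a minimal sketch assuming the usual PyTorch ELU semantics; it is not part of this diff, and elu_reference is a hypothetical name used only for illustration.

#include <cmath>

// Hypothetical reference only; the kernel in this commit uses the shared ATen functors.
// Assumed semantics:
//   ELU(x) = scale * x                                   for x > 0
//          = scale * alpha * (exp(x * input_scale) - 1)  for x <= 0
template <typename T>
T elu_reference(T x, T alpha, T scale, T input_scale) {
  return x > T(0) ? scale * x
                  : scale * alpha * (std::exp(x * input_scale) - T(1));
}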

File tree

5 files changed: +111 -1 lines changed

5 files changed

+111
-1
lines changed

kernels/optimized/cpu/op_elu.cpp

+96
@@ -0,0 +1,96 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/native/cpu/Elu.h>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/kernel/thread_parallel_interface.h>
#include <executorch/runtime/platform/assert.h>

namespace torch::executor::native {

namespace {
template <typename CTYPE>
void elu(
    KernelRuntimeContext& context,
    const Tensor& input,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    Tensor& out) {
  const CTYPE* in_data = input.const_data_ptr<CTYPE>();
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  using MathT =
      std::conditional_t<c10::is_reduced_floating_point_v<CTYPE>, float, CTYPE>;
  MathT math_alpha = 0;
  MathT math_scale = 0;
  MathT math_input_scale = 0;
  ET_EXTRACT_SCALAR(alpha, math_alpha);
  ET_EXTRACT_SCALAR(scale, math_scale);
  ET_EXTRACT_SCALAR(input_scale, math_input_scale);
  const auto scalar_func =
      at::native::get_scalar_elu_elementwise_func<CTYPE, MathT>(
          math_alpha, math_scale, math_input_scale);
  const auto vec_func = at::native::get_vectorized_elu_elementwise_func<CTYPE>(
      math_alpha, math_scale, math_input_scale);

  ::executorch::extension::parallel_for(
      0,
      out.numel(),
      ::executorch::extension::internal::GRAIN_SIZE,
      [&](const auto begin, const auto end) {
        using Vec = at::vec::Vectorized<CTYPE>;
        const auto vectorized_begin =
            begin + (Vec::size() - begin % Vec::size()) % Vec::size();
        const auto vectorized_end = end - (end % Vec::size());
        // Scalar prologue.
        for (const auto idx : c10::irange(begin, vectorized_begin)) {
          out_data[idx] = scalar_func(in_data[idx]);
        }

        // Main vectorized loop.
        for (auto idx = vectorized_begin; idx < vectorized_end;
             idx += Vec::size()) {
          auto result_vec = vec_func(Vec::loadu(&in_data[idx]));
          result_vec.store(&out_data[idx]);
        }

        // Scalar epilogue.
        for (const auto idx : c10::irange(vectorized_end, end)) {
          out_data[idx] = scalar_func(in_data[idx]);
        }
      });
}
} // namespace

Tensor& opt_elu_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    Tensor& out) {
  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);

  ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, "elu.out", CTYPE, [&]() {
    elu<CTYPE>(ctx, in, alpha, scale, input_scale, out);
  });
  return out;
}

} // namespace torch::executor::native
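As a quick illustration of the chunk partitioning inside parallel_for above: each chunk [begin, end) is rounded to the vector width so the middle runs on Vectorized loads/stores and the ragged edges fall back to the scalar functor. The standalone sketch below assumes a vector width of 8 and a sample chunk of [5, 29) purely for illustration; the real width comes from Vec::size() and the chunk bounds from parallel_for.

#include <cstddef>
#include <cstdio>

int main() {
  // Assume Vec::size() == 8 and parallel_for hands this worker the chunk [5, 29).
  const std::size_t vec_size = 8, begin = 5, end = 29;
  // Round begin up to the next multiple of the vector width.
  const std::size_t vectorized_begin =
      begin + (vec_size - begin % vec_size) % vec_size;     // -> 8
  // Round end down to a multiple of the vector width.
  const std::size_t vectorized_end = end - (end % vec_size); // -> 24
  std::printf("scalar prologue: [%zu, %zu)\n", begin, vectorized_begin);           // [5, 8)
  std::printf("vectorized body: [%zu, %zu)\n", vectorized_begin, vectorized_end);  // [8, 24)
  std::printf("scalar epilogue: [%zu, %zu)\n", vectorized_end, end);               // [24, 29)
  return 0;
}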

kernels/optimized/cpu/targets.bzl

+8
@@ -25,6 +25,14 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_elu",
+        deps = [
+            "//executorch/extension/threadpool:threadpool",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
+        ],
+    ),
     op_target(name = "op_exp"),
     op_target(
         name = "op_fft_r2c",

kernels/optimized/optimized.yaml

+5
@@ -37,6 +37,11 @@
     - arg_meta: null
       kernel_name: torch::executor::opt_div_scalar_out

+- op: elu.out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_elu_out
+
 - op: exp.out
   kernels:
     - arg_meta: null

kernels/test/CMakeLists.txt

+1
@@ -274,6 +274,7 @@ set(_optimized_kernels_test_sources
   "op_add_test.cpp"
   "op_bmm_test.cpp"
   "op_div_test.cpp"
+  "op_elu_test.cpp"
   "op_exp_test.cpp"
   "op_fft_r2c_test.cpp"
   "op_gelu_test.cpp"

kernels/test/targets.bzl

+1 -1
@@ -215,7 +215,7 @@ def define_common_targets():
     _common_op_test("op_detach_copy_test", ["aten", "portable"])
     _common_op_test("op_diagonal_copy_test", ["aten", "portable"])
     _common_op_test("op_div_test", ["aten", "portable", "optimized"])
-    _common_op_test("op_elu_test", ["aten", "portable"])
+    _common_op_test("op_elu_test", ["aten", "portable", "optimized"])
     _common_op_test("op_embedding_test", ["aten", "portable"])
     _common_op_test("op_empty_test", ["aten", "portable"])
     _common_op_test("op_eq_test", ["aten", "portable"])
