
Commit 6fc46a6

ckmadhira authored and kirklandsign committed

Added hardtanh operator

Differential Revision: D71816454
Pull Request resolved: #9574

1 parent 2b2ae7c commit 6fc46a6

File tree

3 files changed: +122 −0 lines changed


backends/cadence/aot/functions_fusion_g3.yaml

Lines changed: 5 additions & 0 deletions
@@ -171,6 +171,11 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::G3::exp_out
+
+- op: hardtanh.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::G3::hardtanh_out

 # custom ops
 - func: cadence::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
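This YAML entry binds the hardtanh.out schema to the Fusion G3 kernel named by kernel_name. As a minimal sketch (not part of this diff), the generated registration must resolve to a C++ function with the out-variant signature below; it matches the definition added in op_hardtanh.cpp further down, with fully qualified type names written out for clarity:

    // Sketch: the declaration that kernel_name is expected to resolve to.
    #include <executorch/runtime/kernel/kernel_includes.h>

    namespace cadence {
    namespace impl {
    namespace G3 {
    namespace native {

    // hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1,
    //              *, Tensor(a!) out) -> Tensor(a!)
    ::executorch::aten::Tensor& hardtanh_out(
        ::executorch::runtime::KernelRuntimeContext& ctx,
        const ::executorch::aten::Tensor& in,   // self
        const ::executorch::aten::Scalar& min,  // min_val
        const ::executorch::aten::Scalar& max,  // max_val
        ::executorch::aten::Tensor& out);       // written in place and returned

    } // namespace native
    } // namespace G3
    } // namespace impl
    } // namespace cadence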

backends/cadence/fusion_g3/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ set(_aten_ops__srcs
     "${CMAKE_CURRENT_SOURCE_DIR}/op_lt.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_where.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/op_clamp.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/op_hardtanh.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
backends/cadence/fusion_g3/operators/op_hardtanh.cpp

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/fusion_g3/operators/operators.h>

#include <cmath>

#include <xa_nnlib_kernels_api.h>

#include <executorch/backends/cadence/fusion_g3/operators/xt_macros.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::torch::executor::native::utils::extract_scalar;
using ::torch::executor::native::utils::get_scalar_dtype;

namespace cadence {
namespace impl {
namespace G3 {
namespace native {

Tensor& hardtanh_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Scalar& min,
    const Scalar& max,
    Tensor& out) {
  (void)ctx;

#ifdef OP_ARG_CHECK
  // Resize for dynamic shape
  ET_KERNEL_CHECK_MSG(
      ctx,
      executorch::runtime::resize_tensor(out, in.sizes()) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  ET_KERNEL_CHECK(
      ctx,
      executorch::runtime::tensors_have_same_dim_order(in, out),
      InvalidArgument,
      out);
#endif

  ScalarType in_type = in.scalar_type();
  ScalarType min_type = get_scalar_dtype(min);
  ScalarType max_type = get_scalar_dtype(max);
  ScalarType out_type = out.scalar_type();

  ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);

  if (in_type == ScalarType::Float) {
    const float* const inp1_data = in.const_data_ptr<float>();
    float* const out_data = out.mutable_data_ptr<float>();
    float min_val, max_val;
    extract_scalar(min, &min_val);
    extract_scalar(max, &max_val);

    // Fast path: dispatch float32 inputs to the Cadence NNLib clamp kernel.
    XT_KERNEL_CHECK(
        ctx,
        out,
        xa_nn_elm_clamp_scalar_f32_f32,
        out_data,
        inp1_data,
        min_val,
        max_val,
        out.numel());
  } else {
    ET_SWITCH_REALHBF16_TYPES(in_type, ctx, "hardtanh.out", CTYPE, [&]() {
      CTYPE min_casted;
      ET_SWITCH_SCALAR_OBJ_TYPES(
          min_type, ctx, "hardtanh.out", CTYPE_MIN, [&]() {
            CTYPE_MIN min_val;
            extract_scalar(min, &min_val);
            min_casted = static_cast<CTYPE>(min_val);
          });

      CTYPE max_casted;
      ET_SWITCH_SCALAR_OBJ_TYPES(
          max_type, ctx, "hardtanh.out", CTYPE_MAX, [&]() {
            CTYPE_MAX max_val;
            extract_scalar(max, &max_val);
            max_casted = static_cast<CTYPE>(max_val);
          });

      // Portable fallback: clamp element-wise to [min_casted, max_casted].
      torch::executor::apply_unary_map_fn(
          [min_casted, max_casted](const CTYPE val_in) {
            return torch::executor::native::utils::min_override(
                torch::executor::native::utils::max_override(
                    val_in, min_casted),
                max_casted);
          },
          in.const_data_ptr<CTYPE>(),
          out.mutable_data_ptr<CTYPE>(),
          in.numel());
    });
  }
  return out;
}

} // namespace native
} // namespace G3
} // namespace impl
} // namespace cadence
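For reference, hardtanh clamps every element of the input to [min_val, max_val] (PyTorch defaults to [-1, 1]). The standalone sketch below, plain C++ with no ExecuTorch dependencies, mirrors the math the kernel above implements; hardtanh_ref is a hypothetical helper name used only for illustration:

    #include <algorithm>
    #include <cstdio>

    // Reference behavior of the kernel above:
    // hardtanh(x) = min(max(x, min_val), max_val).
    static float hardtanh_ref(float x, float min_val, float max_val) {
      return std::min(std::max(x, min_val), max_val);
    }

    int main() {
      const float inputs[] = {-2.5f, -0.5f, 0.0f, 0.7f, 3.0f};
      for (float x : inputs) {
        // PyTorch's default bounds are [-1, 1].
        std::printf("hardtanh(%.1f) = %.1f\n", x, hardtanh_ref(x, -1.0f, 1.0f));
      }
      return 0;
    }

Running this prints -1.0, -0.5, 0.0, 0.7, 1.0: values inside the bounds pass through unchanged, values outside saturate.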
