#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

+ #include <array>
+ #include <utility>
+
namespace torch {
namespace executor {
namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
      : s.to<int64_t>();
}

- template <typename CTYPE_COMMON, const char* op_name, typename Op>
- inline void apply_unitensor_elementwise_fn(
+ namespace internal {
+ template <
+     typename CTYPE_COMMON,
+     const char* op_name,
+     typename Op,
+     typename... Args>
+ inline void apply_elementwise_fn(
    const Op& compute_fun,
    KernelRuntimeContext& ctx,
-     const Tensor& a,
-     SupportedTensorDtypes a_dtypes,
    const Tensor& out,
-     SupportedTensorDtypes out_dtypes) {
+     SupportedTensorDtypes out_dtypes,
+     Args... inputs) {
+   static_assert(
+       (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+        ...));
+   constexpr auto kNumInputs = sizeof...(inputs);
  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
+   const auto check_input_dtype = [](auto input, auto compute_type) {
+     return internal::check_tensor_dtype(
+         *input.first, input.second, compute_type);
+   };
  ET_KERNEL_CHECK(
      ctx,
-       (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-        internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+       (check_input_dtype(inputs, compute_type) && ...) &&
+           internal::check_tensor_dtype(out, out_dtypes, compute_type),
      InvalidArgument, );

-   const auto load_a_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+   bool any_is_broadcasted = false;
+   if constexpr (kNumInputs > 1) {
+     any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
+   }
+
+   struct InputInfo {
+     load_to_common_fn<CTYPE_COMMON> load_to_common;
+     const char* data_ptr;
+     ssize_t element_size;
+   };
+   std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
+       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
+           *inputs.first, inputs.second),
+       reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+       inputs.first->element_size(),
+   })...};
+
  const auto store_common_to_out =
      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
-   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-   const auto a_element_size = a.element_size();
-   const auto out_element_size = out.element_size();
  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+   const auto out_element_size = out.element_size();

-   auto out_numel = out.numel();
-   for (const auto i : c10::irange(out_numel)) {
-     auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
-     store_common_to_out(result, &data_out[i * out_element_size]);
+   if (any_is_broadcasted) {
+     for (const auto& indexes :
+          BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+       std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+       for (const auto idx : c10::irange(kNumInputs)) {
+         const auto& input_info = inputs_info[idx];
+         loaded_inputs[idx] = input_info.load_to_common(
+             &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
+       }
+       auto result = std::apply(compute_fun, loaded_inputs);
+       store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+     }
+   } else {
+     for (const auto i : c10::irange(out.numel())) {
+       std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+       for (const auto idx : c10::irange(kNumInputs)) {
+         const auto& input_info = inputs_info[idx];
+         loaded_inputs[idx] = input_info.load_to_common(
+             &input_info.data_ptr[i * input_info.element_size]);
+       }
+       auto result = std::apply(compute_fun, loaded_inputs);
+       store_common_to_out(result, &data_out[i * out_element_size]);
+     }
  }
}
+ } // namespace internal
+
+ template <typename CTYPE_COMMON, const char* op_name, typename Op>
+ inline void apply_unitensor_elementwise_fn(
+     const Op& compute_fun,
+     KernelRuntimeContext& ctx,
+     const Tensor& a,
+     SupportedTensorDtypes a_dtypes,
+     const Tensor& out,
+     SupportedTensorDtypes out_dtypes) {
+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+       compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+ }

/**
 * Useful for bi-tensor elementwise operators. For each element of the inputs,
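A note on the variadic pattern introduced above: `apply_elementwise_fn` constrains each trailing argument to a `(const Tensor*, SupportedTensorDtypes)` pair with a fold expression, expands the pack into a fixed-size `std::array` of per-input info, and uses `std::apply` to fan the loaded values back into `compute_fun`. The standalone sketch below shows the same mechanics in isolation; `FakeTensor`, `apply_fn`, and the `int` dtype tag are hypothetical stand-ins, not part of the ExecuTorch API.

```cpp
#include <array>
#include <cstdio>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for a tensor, used only to illustrate the pattern.
struct FakeTensor {
  float value;
};

template <typename Op, typename... Args>
float apply_fn(const Op& compute_fun, Args... inputs) {
  // Mirrors the static_assert above: every variadic argument must be a
  // (tensor pointer, dtype tag) pair.
  static_assert(
      (std::is_same_v<Args, std::pair<const FakeTensor*, int>> && ...),
      "each input must be a (tensor, dtype) pair");
  constexpr auto kNumInputs = sizeof...(inputs);

  // Expand the pack into a fixed-size array, like the InputInfo array above.
  std::array<float, kNumInputs> loaded = {(inputs.first->value)...};

  // std::apply unpacks the array into kNumInputs arguments of compute_fun.
  return std::apply(compute_fun, loaded);
}

int main() {
  const FakeTensor a{2.0f};
  const FakeTensor b{3.0f};
  float result = apply_fn(
      [](float x, float y) { return x + y; },
      std::make_pair(&a, 0),
      std::make_pair(&b, 0));
  std::printf("%f\n", result); // prints 5.000000
}
```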
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
    SupportedTensorDtypes b_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
-   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-   ET_KERNEL_CHECK(
+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+       compute_fun,
      ctx,
-       (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-        internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-        internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-       InvalidArgument, );
-
-   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-   const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
-   const auto load_a_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-   const auto load_b_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-   const auto store_common_to_out =
-       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-           out, out_dtypes);
-   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-   const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-   const auto a_element_size = a.element_size();
-   const auto b_element_size = b.element_size();
-   const auto out_element_size = out.element_size();
-   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-   auto out_numel = out.numel();
-   if (any_is_broadcasted) {
-     for (const auto [out_index, a_index, b_index] :
-          BroadcastIndexesRange<2>(out, a, b)) {
-       auto result = compute_fun(
-           load_a_to_common(&data_a[a_index * a_element_size]),
-           load_b_to_common(&data_b[b_index * b_element_size]));
-       store_common_to_out(result, &data_out[out_index * out_element_size]);
-     }
-   } else {
-     for (const auto i : c10::irange(out_numel)) {
-       size_t a_linear_index = i;
-       size_t b_linear_index = i;
-
-       auto result = compute_fun(
-           load_a_to_common(&data_a[a_linear_index * a_element_size]),
-           load_b_to_common(&data_b[b_linear_index * b_element_size]));
-       store_common_to_out(result, &data_out[i * out_element_size]);
-     }
-   }
+       out,
+       out_dtypes,
+       std::make_pair(&a, a_dtypes),
+       std::make_pair(&b, b_dtypes));
}

/**
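Call sites are unchanged by this refactor: a kernel still invokes `apply_bitensor_elementwise_fn` with a compute lambda plus per-tensor dtype sets, and only the implementation now funnels through `internal::apply_elementwise_fn`. A rough sketch of such a call site follows; the `my_add_out` wrapper, the `op_name` array, and the `SupportedTensorDtypes::REALHBF16` enumerator are assumptions for illustration (the dtype-set names come from dtype_util.h and are not shown in this diff).

```cpp
// Hypothetical op function calling the public wrapper; ctx, a, b, out are the
// usual kernel arguments.
Tensor& my_add_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // op_name needs static storage so it can bind to the const char* non-type
  // template parameter.
  static constexpr const char op_name[] = "my_add.out";
  // CTYPE_COMMON (float here) is the type every input is converted to before
  // the lambda runs; the lambda never sees the original dtypes.
  apply_bitensor_elementwise_fn<float, op_name>(
      [](float val_a, float val_b) { return val_a + val_b; },
      ctx,
      a,
      SupportedTensorDtypes::REALHBF16, // assumed enumerator name
      b,
      SupportedTensorDtypes::REALHBF16, // assumed enumerator name
      out,
      SupportedTensorDtypes::REALHBF16); // assumed enumerator name
  return out;
}
```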
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
    SupportedTensorDtypes c_dtypes,
    const Tensor& out,
    SupportedTensorDtypes out_dtypes) {
-   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-   ET_KERNEL_CHECK(
+   internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+       compute_fun,
      ctx,
-       (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-        internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-        internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
-        internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-       InvalidArgument, );
-
-   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-   const bool any_is_broadcasted =
-       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
-   const auto load_a_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-   const auto load_b_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-   const auto load_c_to_common =
-       internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-   const auto store_common_to_out =
-       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-           out, out_dtypes);
-   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-   const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-   const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-   const auto a_element_size = a.element_size();
-   const auto b_element_size = b.element_size();
-   const auto c_element_size = c.element_size();
-   const auto out_element_size = out.element_size();
-   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-   auto out_numel = out.numel();
-   if (any_is_broadcasted) {
-     for (const auto [out_index, a_index, b_index, c_index] :
-          BroadcastIndexesRange<3>(out, a, b, c)) {
-       auto result = compute_fun(
-           load_a_to_common(&data_a[a_index * a_element_size]),
-           load_b_to_common(&data_b[b_index * b_element_size]),
-           load_c_to_common(&data_c[c_index * c_element_size]));
-       store_common_to_out(result, &data_out[out_index * out_element_size]);
-     }
-   } else {
-     for (const auto i : c10::irange(out_numel)) {
-       size_t a_linear_index = i;
-       size_t b_linear_index = i;
-       size_t c_linear_index = i;
-
-       auto result = compute_fun(
-           load_a_to_common(&data_a[a_linear_index * a_element_size]),
-           load_b_to_common(&data_b[b_linear_index * b_element_size]),
-           load_c_to_common(&data_c[c_linear_index * c_element_size]));
-       store_common_to_out(result, &data_out[i * out_element_size]);
-     }
-   }
+       out,
+       out_dtypes,
+       std::make_pair(&a, a_dtypes),
+       std::make_pair(&b, b_dtypes),
+       std::make_pair(&c, c_dtypes));
}

inline ScalarType get_compute_type(ScalarType& common_type) {
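The broadcast path in the new `internal::apply_elementwise_fn` consumes one index tuple per output element: `indexes[0]` is the flat output index and `indexes[i + 1]` is the flat index into input `i`, with size-1 dimensions pinned to coordinate 0. The standalone sketch below shows that mapping for a single input; `broadcast_index` is a hypothetical helper for illustration, not the `BroadcastIndexesRange` implementation.

```cpp
#include <cstdio>
#include <vector>

// For a flat, row-major output index, compute the flat index into an input
// whose size-1 dimensions are broadcast against the output shape. Assumes the
// input has already been padded to the same rank as the output.
// Hypothetical helper, for illustration only.
size_t broadcast_index(
    size_t out_index,
    const std::vector<size_t>& out_sizes,
    const std::vector<size_t>& in_sizes) {
  size_t in_index = 0;
  size_t in_stride = 1;
  // Walk dimensions from innermost to outermost.
  for (size_t d = out_sizes.size(); d-- > 0;) {
    const size_t out_coord = out_index % out_sizes[d];
    out_index /= out_sizes[d];
    // A size-1 input dimension always contributes coordinate 0.
    const size_t in_coord = (in_sizes[d] == 1) ? 0 : out_coord;
    in_index += in_coord * in_stride;
    in_stride *= in_sizes[d];
  }
  return in_index;
}

int main() {
  // out is 2x3, the input is 1x3 (broadcast along dim 0).
  for (size_t i = 0; i < 6; ++i) {
    std::printf("out %zu -> in %zu\n", i, broadcast_index(i, {2, 3}, {1, 3}));
  }
  // Prints 0,1,2,0,1,2: the single input row is reused for both output rows.
}
```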