
Commit f91ebe2

Use template instead of duplicating code in elementwise_util.h (#9058)
Now all of the apply functions share a common implementation, so further changes (e.g., adding parallel_for, or generating specialized code for the case where all inputs have the same dtype) no longer need to be repeated 3 times. (Interestingly, this also seems to increase the effectiveness of the follow-up parallelization change. Not entirely sure why, but I checked the generated code for optimized op_where and it appears to have improved, which is surprising.)
1 parent b2badda commit f91ebe2
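Below is a minimal, self-contained C++17 sketch of the pattern this commit applies, not the ExecuTorch code itself: FakeTensor, the int dtype stand-in, and apply_elementwise_fn_sketch are invented for illustration. One variadic helper takes each input as a (pointer, dtype) pair, validates all of them with a fold expression, loads them into a std::array, and invokes the compute functor through std::apply, so the one-, two-, and three-input wrappers can all forward to it.

// Simplified sketch of the variadic-pair + fold-expression + std::apply
// pattern; everything here is a placeholder, not ExecuTorch API.
#include <array>
#include <cstdio>
#include <tuple>
#include <type_traits>
#include <utility>

struct FakeTensor {
  float value; // stand-in for real tensor storage
  int dtype;   // stand-in for SupportedTensorDtypes
};

template <typename Op, typename... Args>
void apply_elementwise_fn_sketch(const Op& compute_fun, Args... inputs) {
  // Every variadic argument must be a (const FakeTensor*, int) pair.
  static_assert(
      (std::is_same_v<Args, std::pair<const FakeTensor*, int>> && ...));
  constexpr auto kNumInputs = sizeof...(inputs);

  // Fold expression: validate every input in one pass, mirroring the dtype
  // checks that the shared helper performs once for all inputs.
  if (!((inputs.second >= 0) && ...)) {
    return;
  }

  // Load each input into an array, then invoke the functor via std::apply so
  // one code path serves any number of inputs.
  std::array<float, kNumInputs> loaded = {inputs.first->value...};
  const float result = std::apply(compute_fun, loaded);
  std::printf("result = %f\n", result);
}

int main() {
  const FakeTensor a{2.0f, 0};
  const FakeTensor b{3.0f, 0};
  // Two-input call; a one- or three-input call reuses the same template.
  apply_elementwise_fn_sketch(
      [](float x, float y) { return x * y; },
      std::make_pair(&a, a.dtype),
      std::make_pair(&b, b.dtype));
  return 0;
}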

File tree

1 file changed: +89, -119 lines

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 89 additions & 119 deletions
@@ -14,6 +14,9 @@
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 
+#include <array>
+#include <utility>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
       : s.to<int64_t>();
 }
 
-template <typename CTYPE_COMMON, const char* op_name, typename Op>
-inline void apply_unitensor_elementwise_fn(
+namespace internal {
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    typename Op,
+    typename... Args>
+inline void apply_elementwise_fn(
     const Op& compute_fun,
     KernelRuntimeContext& ctx,
-    const Tensor& a,
-    SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportedTensorDtypes out_dtypes) {
+    SupportedTensorDtypes out_dtypes,
+    Args... inputs) {
+  static_assert(
+      (std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
+       ...));
+  constexpr auto kNumInputs = sizeof...(inputs);
   constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
+  const auto check_input_dtype = [](auto input, auto compute_type) {
+    return internal::check_tensor_dtype(
+        *input.first, input.second, compute_type);
+  };
   ET_KERNEL_CHECK(
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      (check_input_dtype(inputs, compute_type) && ...) &&
+          internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  bool any_is_broadcasted = false;
+  if constexpr (kNumInputs > 1) {
+    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
+  }
+
+  struct InputInfo {
+    load_to_common_fn<CTYPE_COMMON> load_to_common;
+    const char* data_ptr;
+    ssize_t element_size;
+  };
+  std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
+          *inputs.first, inputs.second),
+      reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
+      inputs.first->element_size(),
+  })...};
+
   const auto store_common_to_out =
       internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto out_element_size = out.element_size();
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
+  const auto out_element_size = out.element_size();
 
-  auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
+  if (any_is_broadcasted) {
+    for (const auto& indexes :
+         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+      for (const auto idx : c10::irange(kNumInputs)) {
+        const auto& input_info = inputs_info[idx];
+        loaded_inputs[idx] = input_info.load_to_common(
+            &input_info.data_ptr[i * input_info.element_size]);
+      }
+      auto result = std::apply(compute_fun, loaded_inputs);
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
   }
 }
+} // namespace internal
+
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
+inline void apply_unitensor_elementwise_fn(
+    const Op& compute_fun,
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
+    const Tensor& out,
+    SupportedTensorDtypes out_dtypes) {
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
+}
 
 /**
  * Useful for bi-tensor elementwise operators. For each element of the inputs,
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes));
 }
 
 /**
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
-  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-
-  ET_KERNEL_CHECK(
+  internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
+      compute_fun,
       ctx,
-      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
-       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
-       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
-       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
-      InvalidArgument, );
-
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
-  const auto load_a_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
-  const auto load_b_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
-  const auto load_c_to_common =
-      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
-  const auto store_common_to_out =
-      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
-          out, out_dtypes);
-  const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
-  const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
-  const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
-  const auto a_element_size = a.element_size();
-  const auto b_element_size = b.element_size();
-  const auto c_element_size = c.element_size();
-  const auto out_element_size = out.element_size();
-  char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
-
-  auto out_numel = out.numel();
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_index * a_element_size]),
-          load_b_to_common(&data_b[b_index * b_element_size]),
-          load_c_to_common(&data_c[c_index * c_element_size]));
-      store_common_to_out(result, &data_out[out_index * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out_numel)) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-      size_t c_linear_index = i;
-
-      auto result = compute_fun(
-          load_a_to_common(&data_a[a_linear_index * a_element_size]),
-          load_b_to_common(&data_b[b_linear_index * b_element_size]),
-          load_c_to_common(&data_c[c_linear_index * c_element_size]));
-      store_common_to_out(result, &data_out[i * out_element_size]);
-    }
-  }
+      out,
+      out_dtypes,
+      std::make_pair(&a, a_dtypes),
+      std::make_pair(&b, b_dtypes),
+      std::make_pair(&c, c_dtypes));
 }
 
 inline ScalarType get_compute_type(ScalarType& common_type) {
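For context, here is a hedged sketch of how a portable op kernel might call the unchanged public wrapper apply_bitensor_elementwise_fn after this refactor. Only the wrapper's template and parameter list are taken from the diff above; the kernel name, op-name string, compute type, and the SupportedTensorDtypes enumerator are illustrative assumptions, and includes and namespace qualification are elided.

// Hypothetical caller; everything except apply_bitensor_elementwise_fn's
// signature is a placeholder for illustration.
// A const char* non-type template argument needs static storage duration.
constexpr char kExampleOpName[] = "example_mul.out"; // assumed name

Tensor& example_mul_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  using CTYPE_COMMON = float; // assumed compute type for this sketch
  apply_bitensor_elementwise_fn<CTYPE_COMMON, kExampleOpName>(
      // Compute functor: receives each input already converted to
      // CTYPE_COMMON and returns the value to store into `out`.
      [](CTYPE_COMMON va, CTYPE_COMMON vb) { return va * vb; },
      ctx,
      a,
      /*a_dtypes=*/SupportedTensorDtypes::REALHBBF16, // assumed enumerator
      b,
      /*b_dtypes=*/SupportedTensorDtypes::REALHBBF16,
      out,
      /*out_dtypes=*/SupportedTensorDtypes::REALHBBF16);
  return out;
}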
