
Use template instead of duplicating code in elementwise_util.h #9058


Merged · 73 commits · Mar 12, 2025

Changes from all commits
47 changes: 39 additions & 8 deletions kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -14,6 +14,7 @@
#include <iterator>
#include <tuple>

#include <executorch/kernels/portable/cpu/util/delinearize_index.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>

@@ -78,7 +79,9 @@ class BroadcastIndexesIterator {
// You might wonder what happens if output_shape_[ii] == 0. In
// that case, output.numel() would be 0, and thus we would have
// begin() == end() and no iteration.
if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
if ET_UNLIKELY (
static_cast<exec_aten::SizesType>(delinearized_output_index_[ii]) ==
output_shape_[ii] - 1) {
const auto old_delinearized_output_index_item =
delinearized_output_index_[ii];
delinearized_output_index_[ii] = 0;
@@ -104,11 +107,42 @@
return it;
}

BroadcastIndexesIterator& operator+=(difference_type n) {
if (n <= 3) {
std::advance(*this, n);
return *this;
}

output_index() += n;
delinearize_index(
output_index(),
output_shape_,
delinearized_output_index_.data(),
delinearized_output_index_.size());
for (const auto ii : c10::irange(1, kNumInputs + 1)) {
current_indexes_[ii] = 0;
for (const auto jj : c10::irange(output_dim_)) {
current_indexes_[ii] += delinearized_output_index_[jj] *
effective_input_broadcast_strides_[ii - 1][jj];
}
}
return *this;
}

BroadcastIndexesIterator operator+(difference_type n) {
auto it = *this;
it += n;
return it;
}

difference_type operator-(const BroadcastIndexesIterator& rhs) const {
return difference_type(output_index() - rhs.output_index());
}

private:
using ShapeType =
std::array<std::size_t, executorch::runtime::kTensorDimensionLimit>;

ssize_t output_index() const {
return current_indexes_[0];
}
@@ -117,11 +151,10 @@ class BroadcastIndexesIterator {
return current_indexes_[0];
}

std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
const {
std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
result = {0};
ShapeType effective_input_broadcast_stride(
const Tensor& output,
const Tensor& t) const {
ShapeType result = {0};
ET_CHECK_MSG(
t.dim() <= output.dim(),
"input to broadcasting op should have dim at most output dim, but %d > %d!",
@@ -146,8 +179,6 @@ class BroadcastIndexesIterator {
// The 0th entry is the current linear index into the output,
// followed by kNumInputs input indexes.
std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
using ShapeType = std::
array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>;
ShapeType delinearized_output_index_ = {0};
ssize_t output_dim_;
ArrayRef<exec_aten::SizesType> output_shape_;
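The centerpiece of this file's change is a constant-time `operator+=`: instead of stepping `n` times, it advances the flat output index directly, delinearizes it back into per-dimension coordinates, and recomputes each input's linear index as a dot product with that input's effective broadcast strides. Below is a minimal standalone sketch of that arithmetic; the helper names (`delinearize`, `dot_with_strides`), the fixed 3-D shape, and the row-major layout are assumptions for illustration, not the ExecuTorch API.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

constexpr std::size_t kDim = 3;

// Convert a flat index into per-dimension coordinates (row-major),
// mirroring what delinearize_index does for the iterator.
std::array<std::size_t, kDim> delinearize(
    std::size_t linear,
    const std::array<std::size_t, kDim>& shape) {
  std::array<std::size_t, kDim> coord{};
  for (std::size_t d = kDim; d-- > 0;) {
    coord[d] = linear % shape[d];
    linear /= shape[d];
  }
  return coord;
}

// An input's linear index is the dot product of the output coordinate
// with that input's broadcast strides (stride 0 in broadcast dims).
std::size_t dot_with_strides(
    const std::array<std::size_t, kDim>& coord,
    const std::array<std::size_t, kDim>& strides) {
  std::size_t idx = 0;
  for (std::size_t d = 0; d < kDim; ++d) {
    idx += coord[d] * strides[d];
  }
  return idx;
}

int main() {
  // Output shape [2, 3, 4]; an input of shape [1, 3, 1] is broadcast over
  // dims 0 and 2, so its effective strides are {0, 1, 0}.
  const std::array<std::size_t, kDim> out_shape = {2, 3, 4};
  const std::array<std::size_t, kDim> in_strides = {0, 1, 0};
  // Jump straight to flat output index 17, as operator+=(17) would,
  // rather than incrementing 17 times.
  const auto coord = delinearize(17, out_shape); // {1, 1, 1}
  std::printf("input index: %zu\n", dot_with_strides(coord, in_strides)); // 1
  return 0;
}
```

Note the small-`n` escape hatch in the real code: for `n <= 3` it falls back to repeated increments via `std::advance`, which avoids the full delinearization for short hops.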
208 changes: 89 additions & 119 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -14,6 +14,9 @@
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

#include <array>
#include <utility>

namespace torch {
namespace executor {
namespace native {
@@ -46,38 +49,94 @@ inline int64_t scalar_to<int64_t>(const Scalar& s) {
: s.to<int64_t>();
}

template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
namespace internal {
template <
typename CTYPE_COMMON,
const char* op_name,
typename Op,
typename... Args>
inline void apply_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
SupportedTensorDtypes out_dtypes,
Args... inputs) {
static_assert(
(std::is_same_v<Args, std::pair<const Tensor*, SupportedTensorDtypes>> &&
...));
constexpr auto kNumInputs = sizeof...(inputs);
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

const auto check_input_dtype = [](auto input, auto compute_type) {
return internal::check_tensor_dtype(
*input.first, input.second, compute_type);
};
ET_KERNEL_CHECK(
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
(check_input_dtype(inputs, compute_type) && ...) &&
[Inline review comment (Contributor): "TIL fold expressions in cpp"]
internal::check_tensor_dtype(out, out_dtypes, compute_type),
InvalidArgument, );

const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
bool any_is_broadcasted = false;
if constexpr (kNumInputs > 1) {
any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
}

struct InputInfo {
load_to_common_fn<CTYPE_COMMON> load_to_common;
const char* data_ptr;
ssize_t element_size;
};
std::array<InputInfo, kNumInputs> inputs_info = {(InputInfo{
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(
*inputs.first, inputs.second),
reinterpret_cast<const char*>(inputs.first->const_data_ptr()),
inputs.first->element_size(),
})...};

const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
const auto a_element_size = a.element_size();
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
const auto out_element_size = out.element_size();

auto out_numel = out.numel();
for (const auto i : c10::irange(out_numel)) {
auto result = compute_fun(load_a_to_common(&data_a[i * a_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
if (any_is_broadcasted) {
for (const auto& indexes :
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto idx : c10::irange(kNumInputs)) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_common(
&input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto idx : c10::irange(kNumInputs)) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_common(
&input_info.data_ptr[i * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[i * out_element_size]);
}
}
}
} // namespace internal

template <typename CTYPE_COMMON, const char* op_name, typename Op>
inline void apply_unitensor_elementwise_fn(
const Op& compute_fun,
KernelRuntimeContext& ctx,
const Tensor& a,
SupportedTensorDtypes a_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
compute_fun, ctx, out, out_dtypes, std::make_pair(&a, a_dtypes));
}

/**
* Useful for bi-tensor elementwise operators. For each element of the inputs,
@@ -94,53 +153,13 @@ inline void apply_bitensor_elementwise_fn(
SupportedTensorDtypes b_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

ET_KERNEL_CHECK(
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
compute_fun,
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
InvalidArgument, );

const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
const auto load_b_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
const auto a_element_size = a.element_size();
const auto b_element_size = b.element_size();
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index] :
BroadcastIndexesRange<2>(out, a, b)) {
auto result = compute_fun(
load_a_to_common(&data_a[a_index * a_element_size]),
load_b_to_common(&data_b[b_index * b_element_size]));
store_common_to_out(result, &data_out[out_index * out_element_size]);
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}
out,
out_dtypes,
std::make_pair(&a, a_dtypes),
std::make_pair(&b, b_dtypes));
}

/**
@@ -175,63 +194,14 @@ inline void apply_tritensor_elementwise_fn(
SupportedTensorDtypes c_dtypes,
const Tensor& out,
SupportedTensorDtypes out_dtypes) {
constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMMON>::value;

ET_KERNEL_CHECK(
internal::apply_elementwise_fn<CTYPE_COMMON, op_name>(
compute_fun,
ctx,
(internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
internal::check_tensor_dtype(out, out_dtypes, compute_type)),
InvalidArgument, );

const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
const bool any_is_broadcasted =
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

const auto load_a_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
const auto load_b_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
const auto load_c_to_common =
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
const auto store_common_to_out =
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
out, out_dtypes);
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
const auto a_element_size = a.element_size();
const auto b_element_size = b.element_size();
const auto c_element_size = c.element_size();
const auto out_element_size = out.element_size();
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index, c_index] :
BroadcastIndexesRange<3>(out, a, b, c)) {
auto result = compute_fun(
load_a_to_common(&data_a[a_index * a_element_size]),
load_b_to_common(&data_b[b_index * b_element_size]),
load_c_to_common(&data_c[c_index * c_element_size]));
store_common_to_out(result, &data_out[out_index * out_element_size]);
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}
out,
out_dtypes,
std::make_pair(&a, a_dtypes),
std::make_pair(&b, b_dtypes),
std::make_pair(&c, c_dtypes));
}

inline ScalarType get_compute_type(ScalarType& common_type) {
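As the inline review comment notes, the consolidation leans on C++17 fold expressions: the variadic `inputs` pack of `std::pair<const Tensor*, SupportedTensorDtypes>` is validated with folds over `&&`, both in the `static_assert` over `std::is_same_v` and in `(check_input_dtype(inputs, compute_type) && ...)`. A toy illustration of the pattern, with an invented predicate standing in for `check_tensor_dtype`:

```cpp
#include <cstdio>
#include <type_traits>

// Stand-in predicate: every argument must be nonzero.
bool check(int v) { return v != 0; }

template <typename... Args>
bool all_ok(Args... args) {
  // Compile-time fold: every element of the pack must be int, mirroring
  // the PR's static_assert that every input is a (Tensor*, dtypes) pair.
  static_assert((std::is_same_v<Args, int> && ...));
  // Run-time fold over &&, short-circuiting on the first failure.
  return (check(args) && ...);
}

int main() {
  std::printf("%d\n", all_ok(1, 2, 3)); // 1
  std::printf("%d\n", all_ok(1, 0, 3)); // 0
  return 0;
}
```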
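The other ingredient that lets a single template replace the unary/binary/ternary variants is `std::apply`: each iteration loads the inputs into a `std::array<CTYPE_COMMON, kNumInputs>` and expands that array into the caller's `compute_fun`. A minimal sketch of that dispatch (the lambda and values are made up for illustration):

```cpp
#include <array>
#include <cstdio>
#include <tuple>

int main() {
  // compute_fun for a hypothetical ternary op, as a kernel might pass
  // to apply_tritensor_elementwise_fn.
  auto compute_fun = [](float a, float b, float c) { return a * b + c; };
  // Loaded inputs, already converted to the common compute type.
  std::array<float, 3> loaded_inputs = {2.0f, 3.0f, 4.0f};
  // std::array models the tuple protocol (std::get / std::tuple_size),
  // so std::apply can unpack it into the three parameters.
  float result = std::apply(compute_fun, loaded_inputs);
  std::printf("%f\n", result); // 10.000000
  return 0;
}
```

With the loading, storing, and broadcasting machinery shared in `internal::apply_elementwise_fn`, the public `apply_unitensor/bitensor/tritensor_elementwise_fn` entry points reduce to one-line forwards that wrap each `(Tensor, dtypes)` argument in a `std::make_pair`.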