[Executorch] Refactor op_add to support op_sub broadcasting #8255


Merged
merged 30 commits into from Feb 18, 2025

Changes from all commits (30 commits)
27a79c4
[Executorch] Refactor op_mul's broadcasting utils
kimishpatel Feb 5, 2025
dbe3e8a
[ExecuTorch] Add broadcast support for optimized add op
kimishpatel Feb 5, 2025
bf761db
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 6, 2025
fb13cd0
[Executorch] Refactor op_add to support op_sub broadcasting
kimishpatel Feb 6, 2025
e12eeb3
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 6, 2025
9c10c86
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 6, 2025
f15d962
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 7, 2025
5f67d96
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 7, 2025
4905ec7
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 7, 2025
fd28b4e
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 7, 2025
a5db857
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 11, 2025
b02b27d
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 11, 2025
06f4c7b
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 11, 2025
de24c75
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 11, 2025
f10c82c
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 12, 2025
43ba1fb
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 12, 2025
3e44312
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 12, 2025
bf27d12
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 12, 2025
e3527dd
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 12, 2025
4f81db5
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 12, 2025
5101f82
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 13, 2025
c680ad4
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 13, 2025
bf29af7
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 13, 2025
28b1a90
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 13, 2025
0de5f8b
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 15, 2025
f9773c8
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 15, 2025
9c774d1
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 15, 2025
c6ae86f
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 15, 2025
e5e2db2
Update base for Update on "[Executorch] Refactor op_add to support op…
kimishpatel Feb 18, 2025
17c29b0
Update on "[Executorch] Refactor op_add to support op_sub broadcasting"
kimishpatel Feb 18, 2025
155 changes: 8 additions & 147 deletions kernels/optimized/cpu/op_add.cpp
@@ -14,59 +14,11 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>

namespace torch {
namespace executor {
namespace native {
namespace {

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void
run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted + alpha_val * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

template <typename CTYPE_IN>
struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug<CTYPE_IN> {};

} // namespace

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;

@@ -76,8 +28,6 @@ Tensor& opt_add_out(
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
(void)ctx;

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
ctx,
torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
InvalidArgument, );
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
CTYPE b_casted = static_cast<CTYPE>(b_val);

@@ -115,100 +67,9 @@
return opt_add_out(ctx, b, a, alpha, out);
}

auto selected_optimized_path = select_optimized_path(a, b, out);
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
// Resize for dynamic shape
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");

ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::map2<CTYPE>(
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
out.mutable_data_ptr<CTYPE>(),
a.const_data_ptr<CTYPE>(),
b.const_data_ptr<CTYPE>(),
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK_MSG(
ctx,
utils::extract_scalar(alpha, &alpha_val),
InvalidArgument,
out,
"Failed to extract scalar alpha.");
using Vec = executorch::vec::Vectorized<CTYPE>;
Vec alpha_val_vec(alpha_val);
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
// Reason we swap out args here is because handle_broadcast_elementwise
// handles this selected_optimized_path option a bit differently.
// This should really be resolved in handle_broadcast_elementwise.
// However, the current blocker is that handle_broadcast_elementwise
// tries to be agnostic of op. This should be fixed, likely by moving
// lambda creation to handle_broadcast_elementwise and it be aware of
// which op is being executed.
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
return y + alpha_val_vec * x;
};
return torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
} else {
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
return x + alpha_val_vec * y;
};
return torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
}
});
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
CTYPE_IN alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

AddInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, alpha_val, out);
});
});
});
}

return out;
static constexpr const char op_name[] = "add.out";
return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
ctx, a, b, alpha, out);
}

Tensor& opt_add_scalar_out(
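
Note: the diff above replaces op_add's hand-rolled broadcasting paths with a single call into a shared add/sub implementation (op_add_sub_impl.h). The sketch below is a minimal illustration, not code from this PR, of how the corresponding op_sub kernel could reuse the same template. The boolean template parameter acting as a subtraction selector, the KernelRuntimeContext parameter type, and the opt_sub_out function name are assumptions made for illustration.

    // Hypothetical sketch (not part of this diff): op_sub reusing the shared
    // add/sub implementation. Assumes the first template parameter of
    // opt_add_sub_out_impl selects subtraction when true and addition when
    // false, matching the <false, op_name> instantiation in op_add.cpp above.
    Tensor& opt_sub_out(
        KernelRuntimeContext& ctx,  // assumed context type
        const Tensor& a,
        const Tensor& b,
        const Scalar& alpha,
        Tensor& out) {
      // Compile-time op name, used by the shared impl for error messages and
      // dtype-switch macros.
      static constexpr const char op_name[] = "sub.out";
      // The shared impl is expected to cover the broadcast fast paths
      // (1d, 2d-by-1d, last-dim, Nd-by-Nd, and their reversed-argument forms)
      // that the removed code above handled directly in op_add.cpp.
      return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
          ctx, a, b, alpha, out);
    }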