Skip to content

[ExecuTorch] Add broadcast support for optimized add op #8205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
27a79c4
[Executorch] Refactor op_mul's broadcasting utils
kimishpatel Feb 5, 2025
dbe3e8a
[ExecuTorch] Add broadcast support for optimized add op
kimishpatel Feb 5, 2025
bf761db
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 6, 2025
0e1cfc7
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 6, 2025
0ce8fd7
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 6, 2025
00e54b8
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 7, 2025
7ea55eb
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 7, 2025
ffb6903
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 7, 2025
e9fe6af
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 7, 2025
e53eb97
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 11, 2025
a91eef8
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 11, 2025
f565c3b
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 11, 2025
656873f
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 11, 2025
8ecbd04
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
2804f70
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
f3406bf
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
132d2f5
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
216c4be
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 12, 2025
bde7998
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 12, 2025
110a932
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 13, 2025
7ebd165
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 13, 2025
5fb4107
Merge branch 'main' into gh/kimishpatel/154/head
kimishpatel Feb 13, 2025
9e0855b
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 13, 2025
0d19ade
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 13, 2025
8955d90
Update base for Update on "[ExecuTorch] Add broadcast support for opt…
kimishpatel Feb 15, 2025
6f2f01a
Update on "[ExecuTorch] Add broadcast support for optimized add op"
kimishpatel Feb 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion kernels/optimized/cpu/binary_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
Expand Down Expand Up @@ -235,7 +236,8 @@ Tensor& handle_broadcast_elementwise(
const Tensor& a,
const Tensor& b,
Tensor& out,
const ElementwiseOptimizedPath selected_optimized_path) {
const ElementwiseOptimizedPath selected_optimized_path,
const executorch::aten::optional<Scalar>& alpha = {}) {
if ((selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDim) ||
(selected_optimized_path ==
Expand Down
63 changes: 32 additions & 31 deletions kernels/optimized/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,40 +140,41 @@ Tensor& opt_add_out(
out.numel());
});
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
const Tensor* lhs;
const Tensor* rhs;
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
lhs = &b;
rhs = &a;
} else {
// Catch failure to update logic when adding new broadcasting possibility.
ET_DCHECK(
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1d);
lhs = &a;
rhs = &b;
}
auto error = resize_tensor(out, lhs->sizes());
ET_KERNEL_CHECK_MSG(
ctx,
error == Error::Ok,
InvalidArgument,
out,
"Failed to resize output tensor.");
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

ET_KERNEL_CHECK_MSG(
ctx,
utils::extract_scalar(alpha, &alpha_val),
InvalidArgument,
out,
"Failed to extract scalar alpha.");
using Vec = executorch::vec::Vectorized<CTYPE>;
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
out.mutable_data_ptr<CTYPE>(),
lhs->const_data_ptr<CTYPE>(),
rhs->const_data_ptr<CTYPE>(),
lhs->sizes()[lhs->dim() - 2],
lhs->sizes()[lhs->dim() - 1]);
Vec alpha_val_vec(alpha_val);
if (selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
selected_optimized_path ==
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
// Reason we swap out args here is because handle_broadcast_elementwise
// handles this selected_optimized_path option a bit differently.
// This should really be resolved in handle_broadcast_elementwise.
// However, the current blocker is that handle_broadcast_elementwise
// tries to be agnostic of op. This should be fixed, likely by moving
// lambda creation to handle_broadcast_elementwise and it be aware of
// which op is being executed.
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
return y + alpha_val_vec * x;
};
return torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
} else {
auto add_lambda = [&alpha_val_vec](auto x, auto y) {
return x + alpha_val_vec * y;
};
return torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
}
});
} else {
ScalarType common_type =
Expand Down
136 changes: 136 additions & 0 deletions kernels/test/op_add_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,125 @@ class OpAddOutKernelTest : public OperatorTest {
// tests.
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
}

  // Verifies broadcasted add over 3-D inputs: b has shape {2, 1, 3} and is
  // broadcast along dim 1 against a's {2, 2, 3}. Both argument orders are
  // exercised, the second with a non-trivial alpha.
  template <ScalarType DTYPE>
  void test_broadcast_3D() {
    TensorFactory<DTYPE> tf_a;

    Tensor a =
        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    Tensor b = tf_a.make({2, 1, 3}, /*data=*/{2, 3, 4, 5, 6, 7});

    // Destination for the output of add; pre-filled with arbitrary data that
    // the op is expected to overwrite.
    Tensor out =
        tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    Tensor expected = tf_a.make(
        {2, 2, 3}, /*data=*/{3, 5, 7, 6, 8, 10, 12, 14, 16, 15, 17, 19});

    // a + 1.0 * b, with b broadcast along dim 1.
    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
    // Swapped arguments with alpha = 1.5: b + 1.5 * a.
    expected = tf_a.make(
        {2, 2, 3},
        /*data=*/{3.5, 6, 8.5, 8, 10.5, 13, 15.5, 18, 20.5, 20, 22.5, 25});
    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.5, out), expected);
  }

template <ScalarType DTYPE>
void test_broadcast_4D() {
TensorFactory<DTYPE> tf_a;

Tensor a = tf_a.make(
{2, 2, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
Tensor b = tf_a.make(
{2, 1, 3, 5},
/*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30});
Comment on lines +142 to +151
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it would probably be more reviewable to fill these programmatically, such as with std::iota, but certainly not blocking


// Destination for output of mul.
Tensor out = tf_a.zeros({2, 2, 3, 5});
Tensor expected = tf_a.make(
{2, 2, 3, 5},
/*data=*/{2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30,
17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45,
47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75,
62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90});
Comment on lines +157 to +160
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto programmatic fill


// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);

b = tf_a.make(
{2, 2, 1, 5}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19, 20});
out = tf_a.zeros({2, 2, 3, 5});
expected = tf_a.make(
{2, 2, 3, 5},
/*data=*/{2, 4, 6, 8, 10, 7, 9, 11, 13, 15, 12, 14, 16, 18, 20,
22, 24, 26, 28, 30, 27, 29, 31, 33, 35, 32, 34, 36, 38, 40,
42, 44, 46, 48, 50, 47, 49, 51, 53, 55, 52, 54, 56, 58, 60,
62, 64, 66, 68, 70, 67, 69, 71, 73, 75, 72, 74, 76, 78, 80});

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
}

  // Verifies broadcasted add where b's last dimension is 1 (the
  // kBroadcastLastDim paths), for 2-D, 3-D, and 4-D inputs, in both
  // argument orders.
  template <ScalarType DTYPE>
  void test_broadcast_last_dim() {
    TensorFactory<DTYPE> tf_a;

    // 2-D: b {4, 1} broadcast along the last dim of a {4, 3}.
    Tensor a =
        tf_a.make({4, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    Tensor b = tf_a.make({4, 1}, /*data=*/{2, 3, 4, 5});

    // Destination for the output of add.
    Tensor out = tf_a.zeros({4, 3});
    Tensor expected =
        tf_a.make({4, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});

    // Check that it matches the expected output (both argument orders;
    // alpha is 1.0, so addition is commutative).
    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);

    // 3-D: b {2, 2, 1} broadcast along the last dim of a {2, 2, 3}.
    a = tf_a.make({2, 2, 3}, /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    b = tf_a.make({2, 2, 1}, /*data=*/{2, 3, 4, 5});

    // Destination for the output of add.
    out = tf_a.zeros({2, 2, 3});
    expected = tf_a.make(
        {2, 2, 3}, /*data=*/{3, 4, 5, 7, 8, 9, 11, 12, 13, 15, 16, 17});

    // Check that it matches the expected output.
    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);

    // 4-D: b {2, 2, 3, 1} broadcast along the last dim of a {2, 2, 3, 5}.
    a = tf_a.make(
        {2, 2, 3, 5},
        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
                  16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
                  31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
                  46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60});
    b = tf_a.make(
        {2, 2, 3, 1},
        /*data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});

    // Destination for the output of add.
    out = tf_a.zeros({2, 2, 3, 5});
    expected = tf_a.make(
        {2, 2, 3, 5},
        /*data=*/{2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18,
                  20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 32, 33, 34, 35, 36,
                  38, 39, 40, 41, 42, 44, 45, 46, 47, 48, 50, 51, 52, 53, 54,
                  56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 68, 69, 70, 71, 72});

    // Check that it matches the expected output.
    EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
    EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
  }
};

class OpAddScalarOutKernelTest : public OperatorTest {
Expand Down Expand Up @@ -371,6 +490,23 @@ TEST_F(OpAddOutKernelTest, BroadcastOneElementRank0Tensor) {
EXPECT_TENSOR_EQ(out, ret);
}

TEST_F(OpAddOutKernelTest, BroadcastNDTest) {
  // Exercise every broadcast shape (3-D, 4-D, and broadcast-over-last-dim)
  // for each dtype the optimized add kernel vectorizes, one dtype at a time.
  test_broadcast_3D<ScalarType::Float>();
  test_broadcast_4D<ScalarType::Float>();
  test_broadcast_last_dim<ScalarType::Float>();

  test_broadcast_3D<ScalarType::Half>();
  test_broadcast_4D<ScalarType::Half>();
  test_broadcast_last_dim<ScalarType::Half>();

  test_broadcast_3D<ScalarType::BFloat16>();
  test_broadcast_4D<ScalarType::BFloat16>();
  test_broadcast_last_dim<ScalarType::BFloat16>();
}

//
// Death Tests
//
Expand Down
10 changes: 0 additions & 10 deletions kernels/test/op_mul_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,6 @@ TEST_F(OpMulOutTest, BroadcastA2BTest) {
test_broadcast_a2b<ScalarType::Int>();
test_broadcast_a2b<ScalarType::Half>();
test_broadcast_a2b<ScalarType::BFloat16>();

// Test 3D tensors
test_broadcast_3D<ScalarType::Float>();
test_broadcast_3D<ScalarType::Half>();
test_broadcast_3D<ScalarType::BFloat16>();

// Test 4D tensors
test_broadcast_4D<ScalarType::Float>();
test_broadcast_4D<ScalarType::Half>();
test_broadcast_4D<ScalarType::BFloat16>();
}

// Broadcast tensor a's size to tensor b's size
Expand Down
Loading