[Executorch] Refactor op_add to support op_sub broadcasting #8255

Merged
merged 30 commits into from
Feb 18, 2025

Conversation

kimishpatel
Contributor

@kimishpatel kimishpatel commented Feb 6, 2025

Stack from ghstack (oldest at bottom):

Summary:
Refactor op_add to consolidate common broadcasting-related improvements

Test Plan:
Previously added tests

Reviewers:

Subscribers:

Tasks:

Tags:

cc @larryliu0820 @manuelcandales

Differential Revision: D69491817

Summary:
Refactor the broadcast handling utils that were added for op_mul. This is
in preparation for using these utils to handle broadcasting for other ops
such as add, sub, and div.
Also remove a redundant test.

Test Plan:
optimized_kernels_test in CI

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
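
The op-agnostic shape of such a shared broadcast helper can be sketched roughly as follows. This is a minimal illustration, not the ExecuTorch implementation: `broadcast_2d_by_1d`, `add_2d_by_1d`, and `mul_2d_by_1d` are hypothetical names, and plain scalar loops stand in for the vectorized paths in the optimized kernels.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: applies `op` to a row-major [rows x cols] input `a`
// and a length-`cols` input `b`, broadcasting `b` across every row.
// The helper knows nothing about which op it is running.
template <typename T, typename Op>
std::vector<T> broadcast_2d_by_1d(
    const Op& op, const std::vector<T>& a, const std::vector<T>& b) {
  const std::size_t cols = b.size();
  const std::size_t rows = cols == 0 ? 0 : a.size() / cols;
  std::vector<T> out(rows * cols);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < cols; ++j) {
      out[i * cols + j] = op(a[i * cols + j], b[j]);
    }
  }
  return out;
}

// Each op only supplies its own lambda; the traversal is shared.
template <typename T>
std::vector<T> mul_2d_by_1d(const std::vector<T>& a, const std::vector<T>& b) {
  return broadcast_2d_by_1d(
      [](T x, T y) { return x * y; }, a, b);
}

template <typename T>
std::vector<T> add_2d_by_1d(
    const std::vector<T>& a, const std::vector<T>& b, T alpha) {
  return broadcast_2d_by_1d(
      [alpha](T x, T y) { return x + alpha * y; }, a, b);
}
```

Keeping the traversal separate from the per-element lambda is what lets add, sub, and div reuse the code path first written for mul.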
Summary:
This brings the add op to feature parity with the mul op, with respect to
broadcasting, in the optimized kernels lib

Test Plan:
tests added

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]

pytorch-bot bot commented Feb 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8255

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 17c29b0 with merge base a01571f (image):

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kimishpatel kimishpatel requested a review from swolchok February 6, 2025 06:40
Contributor

@swolchok swolchok left a comment

looks like this stack is going to need some changes in response to comments on this diff and previous, pausing review here

out,
"Failed to resize output tensor.");

ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
Contributor

need to put const char *op_name in the template parameters and fix this
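
One way to read this suggestion, sketched under assumptions: the shared implementation takes the op name as a `const char*` non-type template parameter, so add and sub each report their own name. `kAddOpName`, `kSubOpName`, and `dtype_error_message` are hypothetical; the real code would presumably forward `op_name` to the `ET_SWITCH_*` macros rather than build a string.

```cpp
#include <cassert>
#include <string>

// Hypothetical op-name constants. A string literal cannot be passed directly
// as a template argument, so each name is a named array with static storage.
static constexpr const char kAddOpName[] = "add.out";
static constexpr const char kSubOpName[] = "sub.out";

// Hypothetical shared implementation: the op name is a template parameter,
// so the same body reports the name of whichever op instantiated it.
template <const char* op_name>
std::string dtype_error_message() {
  return std::string("unsupported dtype in ") + op_name;
}
```

With this shape, the hard-coded "add.out" in the shared code is replaced by `op_name`, and the sub entry point instantiates the same template with its own constant.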

Comment on lines 168 to 169
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
Contributor

ditto op name

@kimishpatel
Contributor Author

> looks like this stack is going to need some changes in response to comments on this diff and previous, pausing review here

yeah sounds good. Let me address your comments in the previous diffs

@kimishpatel kimishpatel added module: kernels Issues related to kernel libraries and utilities, and code under kernels/ release notes: ops & kernels Changes to the opset and any new / changed kernel implementations labels Feb 7, 2025
@kimishpatel kimishpatel requested a review from swolchok February 11, 2025 00:07
Comment on lines +129 to +176
    if constexpr (is_sub) {
      if (selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
          selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
          selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
          return y - alpha_val_vec * x;
        };
        return torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
      } else {
        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
          return x - alpha_val_vec * y;
        };
        return torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
      }
    } else {
      if (selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
          selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
          selected_optimized_path ==
              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
        // Reason we swap out args here is because
        // handle_broadcast_elementwise handles this selected_optimized_path
        // option a bit differently. This should really be resolved in
        // handle_broadcast_elementwise. However, the current blocker is that
        // handle_broadcast_elementwise tries to be agnostic of op. This
        // should be fixed, likely by moving lambda creation to
        // handle_broadcast_elementwise and it be aware of which op is being
        // executed.
        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
          return y + alpha_val_vec * x;
        };
        return torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
      } else {
        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
          return x + alpha_val_vec * y;
        };
        return torch::executor::handle_broadcast_elementwise<CTYPE>(
            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
      }
    }
    });
Contributor

I think you can just select the lambdas based on is_sub rather than duplicating the rest of the code under this if constexpr: https://godbolt.org/z/Esdz1exKj
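
A minimal sketch of that idea, assuming nothing about the contents of the linked Godbolt snippet: the add-versus-sub choice is confined to one small lambda, and the argument-order handling is written once. `apply_binary` is a hypothetical stand-in for `handle_broadcast_elementwise`, `add_or_sub` and `reverse_arguments` are illustrative names, and scalars stand in for the vectorized values.

```cpp
#include <cassert>

// Hypothetical stand-in for the broadcast handler: applies `op` once.
template <typename T, typename Op>
T apply_binary(const Op& op, T x, T y) {
  return op(x, y);
}

template <bool is_sub, typename T>
T add_or_sub(T a, T b, T alpha, bool reverse_arguments) {
  // The only place that depends on is_sub: pick the combining step at
  // compile time. Everything below is shared between add and sub.
  auto combine = [alpha](auto lhs, auto rhs) {
    if constexpr (is_sub) {
      return lhs - alpha * rhs;
    } else {
      return lhs + alpha * rhs;
    }
  };
  if (reverse_arguments) {
    // The handler sees the operands in swapped order on this path, so the
    // lambda swaps them back before combining.
    auto op = [&combine](auto x, auto y) { return combine(y, x); };
    return apply_binary<T>(op, b, a);
  }
  auto op = [&combine](auto x, auto y) { return combine(x, y); };
  return apply_binary<T>(op, a, b);
}
```

This keeps one copy of the path check and the handler call per argument order instead of one per op and argument order.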

Contributor Author

That's a good suggestion. I tried doing the same in a different way, which didn't quite work, but I can try out your suggestion.

@kimishpatel
Contributor Author

@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@kimishpatel kimishpatel changed the base branch from gh/kimishpatel/155/base to main February 15, 2025 04:30
@kimishpatel kimishpatel merged commit 1b6c12d into main Feb 18, 2025
45 of 48 checks passed
@kimishpatel kimishpatel deleted the gh/kimishpatel/155/head branch February 18, 2025 21:04
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: kernels Issues related to kernel libraries and utilities, and code under kernels/ release notes: ops & kernels Changes to the opset and any new / changed kernel implementations
4 participants