-
Notifications
You must be signed in to change notification settings - Fork 593
[Executorch] Refactor op_add to support op_sub broadcasting #8255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: Refactoring broadcast handling utils that were added for op_mul. This is in prepartion use these utils to handle broadcast for other ops such as add, sub, div. Plus remove a redundant test Test Plan: optimized_kernels_test in CI Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: This brings add op to feature parity, wrt, broadcasting, to mul op in optimized kernels lib Test Plan: tests added Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8255
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 17c29b0 with merge base a01571f ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like this stack is going to need some changes in response to comments on this diff and previous, pausing review here
out, | ||
"Failed to resize output tensor."); | ||
|
||
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to put const char *op_name
in the template parameters and fix this
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { | ||
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto op name
yeah sounds good. Let me address your comments in the previous diffs |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales [ghstack-poisoned]
if constexpr (is_sub) { | ||
if (selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments || | ||
selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments || | ||
selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) { | ||
auto add_lambda = [&alpha_val_vec](auto x, auto y) { | ||
return y - alpha_val_vec * x; | ||
}; | ||
return torch::executor::handle_broadcast_elementwise<CTYPE>( | ||
ctx, add_lambda, a, b, out, selected_optimized_path, alpha); | ||
} else { | ||
auto add_lambda = [&alpha_val_vec](auto x, auto y) { | ||
return x - alpha_val_vec * y; | ||
}; | ||
return torch::executor::handle_broadcast_elementwise<CTYPE>( | ||
ctx, add_lambda, a, b, out, selected_optimized_path, alpha); | ||
} | ||
} else { | ||
if (selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments || | ||
selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments || | ||
selected_optimized_path == | ||
ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) { | ||
// Reason we swap out args here is because | ||
// handle_broadcast_elementwise handles this selected_optimized_path | ||
// option a bit differently. This should really be resolved in | ||
// handle_broadcast_elementwise. However, the current blocker is that | ||
// handle_broadcast_elementwise tries to be agnostic of op. This | ||
// should be fixed, likely by moving lambda creation to | ||
// handle_broadcast_elementwise and it be aware of which op is being | ||
// executed. | ||
auto add_lambda = [&alpha_val_vec](auto x, auto y) { | ||
return y + alpha_val_vec * x; | ||
}; | ||
return torch::executor::handle_broadcast_elementwise<CTYPE>( | ||
ctx, add_lambda, a, b, out, selected_optimized_path, alpha); | ||
} else { | ||
auto add_lambda = [&alpha_val_vec](auto x, auto y) { | ||
return x + alpha_val_vec * y; | ||
}; | ||
return torch::executor::handle_broadcast_elementwise<CTYPE>( | ||
ctx, add_lambda, a, b, out, selected_optimized_path, alpha); | ||
} | ||
} | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can just select the lambdas based on is_sub rather than duplicating the rest of the code under this if constexpr
: https://godbolt.org/z/Esdz1exKj
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thats a good suggestion. I tried doing the same in a different way which didnt quite work, but i can try out your suggestion
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
…_sub broadcasting" Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
Summary: Refactor op_add to conslidate commong broadcasting related improvements Test Plan: Previously added tests Reviewers: Subscribers: Tasks: Tags: cc larryliu0820 manuelcandales Differential Revision: [D69491817](https://our.internmc.facebook.com/intern/diff/D69491817) [ghstack-poisoned]
@kimishpatel has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Stack from ghstack (oldest at bottom):
Summary:
Refactor op_add to conslidate commong broadcasting related improvements
Test Plan:
Previously added tests
Reviewers:
Subscribers:
Tasks:
Tags:
cc @larryliu0820 @manuelcandales
Differential Revision: D69491817