[ExecuTorch] Add broadcast support for optimized add op #8205
@@ -8,6 +8,8 @@
#pragma once

#include <executorch/kernels/optimized/vec/functional.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
@@ -47,8 +49,38 @@ enum class ElementwiseOptimizedPath {
  kBroadcastLastDimReverseArguments,
};

enum class BinaryOpType {
  kAdd,
  kSub,
  kMul,
  kDiv,
};

namespace internal {

template <BinaryOpType op_type>
struct BinaryOpTypeName;

template <>
struct BinaryOpTypeName<BinaryOpType::kAdd> {
  static constexpr char kName[] = "add.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kSub> {
  static constexpr char kName[] = "sub.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kMul> {
  static constexpr char kName[] = "mul.out";
};

template <>
struct BinaryOpTypeName<BinaryOpType::kDiv> {
  static constexpr char kName[] = "div.out";
};
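For orientation, a small usage sketch (illustrative only, not part of the PR; it assumes C++17 so the string_view comparison is a constant expression): the trait simply resolves, at compile time, to the out-variant schema name that the ET_SWITCH_* macros use for error reporting.

#include <string_view>

// Inside namespace torch::executor::internal:
static_assert(
    std::string_view(BinaryOpTypeName<BinaryOpType::kAdd>::kName) == "add.out",
    "each BinaryOpType maps to its out-variant schema name");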
/*
Given two tensors, this function returns the broadcast dim if it exists.
Returns 0 if no broadcast dim is found.

@@ -190,5 +222,145 @@ std::array<int32_t, 3> inline get_normalized_tensor_size(
  return normalized_tensor_size;
}
template <BinaryOpType op_type, typename Op>
Tensor& handle_last_dim_broadcast_elementwise(
    KernelRuntimeContext& ctx,
    const Op& vec_fun,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path,
    executorch::aten::optional<Scalar>& alpha = {}) {
Review comment: The error messages are telling you this needs to be a const ref. Also, why is this not std::optional?

Author reply: Yeah, just realized that. Not sure why it did not throw an error in my local build; different compile options, I guess. I just followed what I saw elsewhere. Happy to switch to std::optional too, which I guess is what backs it, but maybe for the ATen build it aliases c10::optional? Let me check that first.

Review comment: c10::optional is gone.
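A minimal sketch of the fix the first comment asks for (assuming the parameter type stays executorch::aten::optional; only the reference qualification changes): the default argument {} materializes a temporary, and a temporary can bind only to a const reference.

template <BinaryOpType op_type, typename Op>
Tensor& handle_last_dim_broadcast_elementwise(
    KernelRuntimeContext& ctx,
    const Op& vec_fun,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path,
    // const ref: a non-const lvalue reference cannot bind to the {} temporary
    const executorch::aten::optional<Scalar>& alpha = {});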
  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if (selected_optimized_path ==
      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
    lhs = &b;
    rhs = &a;
  } else {
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
  const auto broadcast_size = out.size(out.dim() - 1);
  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    CTYPE alpha_val;
    Vec alpha_val_vec(alpha_val);
Review comment: Normally I would say "alpha_val needs to be initialized; C++ doesn't have default zero-initialization for primitives", but actually the problem here is that …
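One way to address it (a sketch, not the PR's final code; it assumes 1 is an acceptable default since alpha scales an addend): extract the scalar before constructing the vector, so no Vec is ever built from an indeterminate value.

CTYPE alpha_val = static_cast<CTYPE>(1);  // defined default; overwritten if alpha is set
if (alpha.has_value()) {
  ET_KERNEL_CHECK(
      ctx,
      native::utils::extract_scalar(alpha.value(), &alpha_val),
      InvalidArgument, );
}
Vec alpha_val_vec(alpha_val);  // constructed only after alpha_val holds a defined value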
    if (alpha.has_value()) {
      ET_KERNEL_CHECK(
          ctx,
          native::utils::extract_scalar(alpha.value(), &alpha_val),
          InvalidArgument, );
      alpha_val_vec = Vec(alpha_val);
    }
    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
      return vec_fun(a, b, alpha_val_vec);
    };
    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
        vec_fun_alpha,
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size);
  });
  return out;
}
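For intuition, an illustrative shape walk-through (values invented for this note, not taken from the PR): with lhs of shape [2, 3, 5] and rhs of shape [2, 3, 1], the last dimension is the broadcast one, so outer_size = getLeadingDims(out, out.dim() - 1) = 2 * 3 = 6 and broadcast_size = out.size(out.dim() - 1) = 5; the map then applies vec_fun_alpha across six rows of five elements, broadcasting each row's single rhs element.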
template <BinaryOpType op_type, typename Op>
Tensor& handle_broadcast_elementwise(
    KernelRuntimeContext& ctx,
    const Op& vec_fun,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path,
    executorch::aten::optional<Scalar> alpha = {}) {
Review comment: Why is this by-value but the other one is a reference? Make them consistent.

Author reply: Oh, good callout. My bad.
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDim) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
    return handle_last_dim_broadcast_elementwise<op_type>(
        ctx, vec_fun, a, b, out, selected_optimized_path, alpha);
  }

  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    lhs = &b;
    rhs = &a;
  } else {
    // Catch failure to update logic when adding new broadcasting possibility.
    ET_DCHECK(
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcastNdByNd));
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  int64_t outer_size = 1;
  int64_t broadcast_size;
  int64_t inner_size;
  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
    auto normalized_tensor_size_lhs =
        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
    outer_size = normalized_tensor_size_lhs[0];
    broadcast_size = normalized_tensor_size_lhs[1];
    inner_size = normalized_tensor_size_lhs[2];
  } else {
    broadcast_size = lhs->sizes()[lhs->dim() - 2];
    inner_size = lhs->sizes()[lhs->dim() - 1];
  }
  ET_SWITCH_REALB_TYPES(out_type, ctx, internal::BinaryOpTypeName<op_type>::kName, CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    CTYPE alpha_val;
Review comment: Same problem as above.
    Vec alpha_val_vec;
    if (alpha.has_value()) {
      ET_KERNEL_CHECK(
          ctx,
          native::utils::extract_scalar(alpha.value(), &alpha_val),
          InvalidArgument, );
      alpha_val_vec = Vec(alpha_val);
    }
    auto vec_fun_alpha = [vec_fun, alpha_val_vec](const Vec& a, const Vec& b) {
      return vec_fun(a, b, alpha_val_vec);
    };
    executorch::vec::
        broadcasting_map_3d_and_unsqueezed_3d<CTYPE, decltype(vec_fun_alpha)>(
            vec_fun_alpha,
            out.mutable_data_ptr<CTYPE>(),
            lhs->const_data_ptr<CTYPE>(),
            rhs->const_data_ptr<CTYPE>(),
            outer_size,
            broadcast_size,
            inner_size);
  });
  return out;
}
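Similarly for the NdByNd path, an illustrative example (shapes invented here; it assumes get_broadcast_dim reports the broadcast axis as a negative offset, which is what the lhs->dim() + broadcast_dim arithmetic suggests): lhs of shape [2, 3, 4, 5] against rhs of shape [2, 1, 4, 5] yields broadcast_dim = -3, hence broadcast_dim_lhs = 4 + (-3) = 1, and get_normalized_tensor_size collapses the tensor around that axis into outer_size = 2, broadcast_size = 3, inner_size = 4 * 5 = 20, which is exactly the 3-D view that broadcasting_map_3d_and_unsqueezed_3d consumes.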
} // namespace executor
} // namespace torch
@@ -140,41 +140,31 @@ Tensor& opt_add_out(
        out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    const Tensor* lhs;
    const Tensor* rhs;
    if (selected_optimized_path ==
        ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
      lhs = &b;
      rhs = &a;
            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
        selected_optimized_path ==
            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
        selected_optimized_path ==
            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
      // This behavior is a bit confusing.
Review comment: I don't understand what's confusing here; there is an argument that should be scaled by alpha_val, and we have to scale the right one. I definitely don't think handle_broadcast_elementwise should be coupled to the specific op.

Author reply: The problem is this: all the reverse-argument paths have specifically different handling inside.

Author reply: I guess "confusing" is not the right word here, though.
      // Reason we swap out args here is because handle_broadcast_elementwise
      // handles this selected_optimized_path option a bit differently.
      // This should really be resoled in handle_broadcast_elementwise.
Review comment: s/resoled/resolved/
      // However, the current blocker is that handle_broadcast_elementwise
      // tries to be agnostic of op. This should be fixed, likely by moving
      // lambda creation to handle_broadcast_elementwise and making it aware
      // of which op is being executed.
      auto add_lambda = [](auto x, auto y, auto alpha_val) {
        return y + alpha_val * x;
      };
      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
    } else {
      // Catch failure to update logic when adding new broadcasting possibility.
      ET_DCHECK(
          selected_optimized_path ==
          ElementwiseOptimizedPath::kBroadcast2dBy1d);
      lhs = &a;
      rhs = &b;
      auto add_lambda = [](auto x, auto y, auto alpha_val) {
        return x + alpha_val * y;
      };
      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
    }
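To spell out why the two lambdas differ (a reading of the code, not wording from the PR): the add.out contract is out = a + alpha * b, and handle_broadcast_elementwise always passes the larger tensor as the lambda's first parameter. On the reverse-argument paths that first parameter is b, so the lambda must scale its first argument, y + alpha_val * x, to keep alpha attached to b; on the direct paths it scales the second, x + alpha_val * y.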
    auto error = resize_tensor(out, lhs->sizes());
    ET_KERNEL_CHECK_MSG(
        ctx,
        error == Error::Ok,
        InvalidArgument,
        out,
        "Failed to resize output tensor.");
    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
      CTYPE alpha_val;
      ET_KERNEL_CHECK(
          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

      using Vec = executorch::vec::Vectorized<CTYPE>;
      executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
          out.mutable_data_ptr<CTYPE>(),
          lhs->const_data_ptr<CTYPE>(),
          rhs->const_data_ptr<CTYPE>(),
          lhs->sizes()[lhs->dim() - 2],
          lhs->sizes()[lhs->dim() - 1]);
    });
  } else {
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
@@ -68,114 +68,6 @@ template <
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
    : public ReportCanCastBug {};
Tensor& handle_last_dim_broadcast(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if (selected_optimized_path ==
      ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments) {
    lhs = &b;
    rhs = &a;
  } else {
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  const size_t outer_size = getLeadingDims(out, out.dim() - 1);
  const auto broadcast_size = out.size(out.dim() - 1);
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_broadcast_last_dim<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size);
  });
  return out;
}
Tensor& handle_broadcast_mul(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out,
    const ElementwiseOptimizedPath selected_optimized_path) {
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDim) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments)) {
    return handle_last_dim_broadcast(ctx, a, b, out, selected_optimized_path);
  }

  ScalarType out_type = out.scalar_type();
  const Tensor* lhs;
  const Tensor* rhs;
  if ((selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    lhs = &b;
    rhs = &a;
  } else {
    // Catch failure to update logic when adding new broadcasting possibility.
    ET_DCHECK(
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcast2dBy1d) ||
        (selected_optimized_path ==
         ElementwiseOptimizedPath::kBroadcastNdByNd));
    lhs = &a;
    rhs = &b;
  }
  auto error = resize_tensor(out, lhs->sizes());
  ET_KERNEL_CHECK_MSG(
      ctx,
      error == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");
  int64_t outer_size = 1;
  int64_t broadcast_size;
  int64_t inner_size;
  if ((selected_optimized_path == ElementwiseOptimizedPath::kBroadcastNdByNd) ||
      (selected_optimized_path ==
       ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments)) {
    int32_t broadcast_dim = internal::get_broadcast_dim(*lhs, *rhs);
    int32_t broadcast_dim_lhs = lhs->dim() + broadcast_dim;
    auto normalized_tensor_size_lhs =
        get_normalized_tensor_size(*lhs, broadcast_dim_lhs);
    outer_size = normalized_tensor_size_lhs[0];
    broadcast_size = normalized_tensor_size_lhs[1];
    inner_size = normalized_tensor_size_lhs[2];
  } else {
    broadcast_size = lhs->sizes()[lhs->dim() - 2];
    inner_size = lhs->sizes()[lhs->dim() - 1];
  }
  ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
    using Vec = executorch::vec::Vectorized<CTYPE>;
    executorch::vec::broadcasting_map_3d_and_unsqueezed_3d<CTYPE>(
        [](Vec x, Vec y) { return x * y; },
        out.mutable_data_ptr<CTYPE>(),
        lhs->const_data_ptr<CTYPE>(),
        rhs->const_data_ptr<CTYPE>(),
        outer_size,
        broadcast_size,
        inner_size);
  });
  return out;
}
} // namespace
Tensor& opt_mul_out(

@@ -238,7 +130,13 @@ Tensor& opt_mul_out(
        out.numel());
    });
  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
    return handle_broadcast_mul(ctx, a, b, out, selected_optimized_path);
    // Reason for using alpha:
Review comment: Missing the rest of the comment after the colon.
    auto mul_lambda = [](auto x, auto y, auto alpha) {
      (void)alpha;
Review comment: …

Author reply: Thank you :)
      return x * y;
    };
    return torch::executor::handle_broadcast_elementwise<BinaryOpType::kMul>(
        ctx, mul_lambda, a, b, out, selected_optimized_path);
  } else {
    ScalarType common_type =
        promoteTypes(a_type, b_type, /*half_to_float*/ true);
Review comment: You don't need to do this. See the existing example: executorch/kernels/portable/cpu/op_rsub.cpp, lines 50 to 55 (at c82a7df). The secret sauce is that the string literal has to be a static constexpr const char[], and then you can pass it to a const char* template argument directly.

Author reply: Thanks. I was hoping you would point me to something better for this.
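A minimal, self-contained illustration of the trick the reviewer describes (names invented for this sketch; op_rsub.cpp is the authoritative reference): since C++17, a constexpr character array at namespace scope can be passed directly to a const char* non-type template parameter, so no BinaryOpTypeName-style wrapper struct is required.

#include <cstdio>

inline constexpr char kAddName[] = "add.out";  // static storage duration, constexpr

template <const char* kOpName>
void report_op() {
  // kOpName is a compile-time constant that points at kAddName.
  std::printf("running %s\n", kOpName);
}

int main() {
  report_op<kAddName>();  // prints "running add.out"
  return 0;
}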