Use compile-time promotion to reduce optimized add/sub op size & build time #3533
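This PR rewrites the type dispatch in the optimized add, sub, and mul kernels (and the portable bitwise_and) so that the promoted "common" compute type is derived at compile time instead of through a runtime ET_SWITCH. Previously each kernel nested four switches, over a_type, b_type, common_type, and out_type, instantiating the elementwise lambda for every combination; since common_type is a pure function of the two input types, promote_types<CTYPE_A, CTYPE_B> can supply it per instantiation, and a can_cast-specialized helper (AddInner/SubInner/MulInner) collapses the input/output combinations that the runtime canCast check rejects into a single trivial body. Below is a minimal standalone sketch of the pattern; PromoteTypes is a hypothetical stand-in for torch::executor::promote_types built on std::common_type (the real trait also implements half-to-float promotion and torch's promotion rules), and none of these names are the ExecuTorch API.

// Minimal standalone sketch of the dispatch pattern this PR introduces
// (illustrative only, not the ExecuTorch API).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <type_traits>

template <typename A, typename B>
struct PromoteTypes {
  // std::common_type approximates real type promotion closely enough
  // for this illustration.
  using type = typename std::common_type<A, B>::type;
};

// The primary template handles can_cast == true; the false specialization
// gives invalid IN -> OUT combinations a trivial body, so the elementwise
// loop is never instantiated for them. At runtime those combinations are
// rejected before dispatch ever reaches this point.
template <bool can_cast, typename A, typename B, typename IN, typename OUT>
struct AddInner {
  static void run(const A* a, const B* b, OUT* out, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
      IN value = static_cast<IN>(a[i]) + static_cast<IN>(b[i]);
      out[i] = static_cast<OUT>(value);
    }
  }
};

template <typename A, typename B, typename IN, typename OUT>
struct AddInner<false, A, B, IN, OUT> {
  static void run(const A*, const B*, OUT*, std::size_t) {
    // Mirrors ReportCanCastBug: reaching this is a bug, because the
    // caller checked canCast before dispatching.
  }
};

int main() {
  const int32_t a[] = {1, 2, 3};
  const float b[] = {0.5f, 0.25f, 0.125f};
  float out[3];
  // The promoted type is computed at compile time: no runtime switch.
  using IN = PromoteTypes<int32_t, float>::type; // == float
  AddInner<true, int32_t, float, IN, float>::run(a, b, out, 3);
  std::printf("%g %g %g\n", out[0], out[1], out[2]); // prints: 1.5 2.25 3.125
  return 0;
}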

kernels/optimized/cpu/op_add.cpp (63 additions, 20 deletions)
@@ -16,6 +16,55 @@
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
+
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
 
 using Tensor = exec_aten::Tensor;
 using ScalarType = exec_aten::ScalarType;
@@ -69,26 +118,20 @@ Tensor& opt_add_out(
 
   ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
     ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-      ET_SWITCH_REALB_TYPES(common_type, ctx, "add.out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx,
-              utils::extract_scalar(alpha, &alpha_val),
-              InvalidArgument, );
-
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
+        CTYPE_IN alpha_val;
+        ET_KERNEL_CHECK(
+            ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+
+        AddInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, alpha_val, out);
+      });
     });
   });
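Where the size and build-time savings come from, in rough numbers: if ET_SWITCH_REALHB_TYPES covers nine scalar types (seven real types plus Half and Bool) and ET_SWITCH_REALB_TYPES eight, the old four-deep nesting instantiated on the order of 9 × 9 × 8 × 9 ≈ 5,800 copies of the add lambda, whereas computing CTYPE_IN at compile time leaves 9 × 9 × 9 = 729, roughly an eightfold reduction. The new ET_DCHECK keeps the two schemes consistent: it asserts that the compile-time promote_types result matches the common_type the kernel already computed at runtime.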
kernels/optimized/cpu/op_mul.cpp (59 additions, 14 deletions)
@@ -41,6 +41,50 @@ bool can_use_optimized_path(
       (a.numel() == b.numel() && a.numel() == out.numel()));
   return can_use_optimized_path;
 }
+
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MulInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
 } // namespace
 
 Tensor& opt_mul_out(
@@ -86,20 +130,21 @@ Tensor& opt_mul_out(
 
   ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
     ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
-      ET_SWITCH_REALB_TYPES(common_type, ctx, "mul.out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
+        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+            [](const CTYPE_A val_a, const CTYPE_B val_b) {
+              CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+              CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+              CTYPE_IN value = a_casted * b_casted;
+
+              return static_cast<CTYPE_OUT>(value);
+            },
+            a,
+            b,
+            out);
+      });
     });
   });
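op_mul gets the same treatment, though this hunk keeps the elementwise lambda inline since there is no alpha scalar to capture; the only change inside the switches is that CTYPE_IN becomes a compile-time alias (the MulInner helper defined above presumably serves the file's other call sites). A small compile-time illustration of what the alias buys, using std::common_type as a rough stand-in for promote_types (the real trait also handles half-to-float promotion and torch's mixed-sign integer rules, where the two diverge):

// Standalone sketch: the promoted type is a pure function of the two
// input types, so it is known at template-instantiation time and needs
// no runtime switch over common_type.
#include <cstdint>
#include <type_traits>

template <typename A, typename B>
using Promoted = typename std::common_type<A, B>::type;

static_assert(
    std::is_same<Promoted<int32_t, double>, double>::value,
    "int32 * double promotes to double");
static_assert(
    std::is_same<Promoted<int32_t, int64_t>, int64_t>::value,
    "int32 * int64 promotes to int64");

int main() {
  return 0; // all the checking happens at compile time
}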
kernels/optimized/cpu/op_sub.cpp (62 additions, 20 deletions)
@@ -17,6 +17,55 @@
 namespace torch {
 namespace executor {
 namespace native {
+namespace {
+
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted - alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
 
 using Tensor = exec_aten::Tensor;
 using ScalarType = exec_aten::ScalarType;
@@ -72,26 +121,19 @@ Tensor& opt_sub_out(
 
   ET_SWITCH_REALH_TYPES(a_type, ctx, "sub.out", CTYPE_A, [&]() {
     ET_SWITCH_REALH_TYPES(b_type, ctx, "sub.out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(common_type, ctx, "sub.out", CTYPE_IN, [&]() {
-        ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx,
-              utils::extract_scalar(alpha, &alpha_val),
-              InvalidArgument, );
-
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted - alpha_val * b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
+      using CTYPE_IN = typename torch::executor::
+          promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+      ET_SWITCH_REALH_TYPES(out_type, ctx, "sub.out", CTYPE_OUT, [&]() {
+        CTYPE_IN alpha_val;
+        ET_KERNEL_CHECK(
+            ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+        SubInner<
+            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+            CTYPE_A,
+            CTYPE_B,
+            CTYPE_IN,
+            CTYPE_OUT>::run(a, b, alpha_val, out);
+      });
     });
   });
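Note the narrower switch lists in opt_sub_out: REALH/REAL rather than REALHB/REALB, since subtraction is not defined for bool tensors (consistent with ATen, which rejects the - operator on booleans); apart from that and the a_casted - alpha_val * b_casted body, the structure mirrors opt_add_out.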
kernels/portable/cpu/op_bitwise_and.cpp (18 additions, 45 deletions)
@@ -6,8 +6,10 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <cmath>
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <functional>
 
+#include <executorch/kernels/portable/cpu/pattern/bitwise_op.h>
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -17,20 +19,6 @@ namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-template <typename CTYPE>
-CTYPE bitwise_and(CTYPE a, CTYPE b) {
-  return a & b;
-}
-
-template <>
-bool bitwise_and<bool>(bool a, bool b) {
-  return a && b;
-}
-
-} // namespace
-
 using Tensor = exec_aten::Tensor;
 
 Tensor& bitwise_and_Tensor_out(
@@ -55,38 +43,23 @@ Tensor& bitwise_and_Tensor_out(
       Bool, a_type, ctx, "bitwise_and.Tensor_out", CTYPE_A, [&]() {
         ET_SWITCH_INT_TYPES_AND(
             Bool, b_type, ctx, "bitwise_and.Tensor_out", CTYPE_B, [&]() {
-              ET_SWITCH_INT_TYPES_AND(
+              using CTYPE_IN = typename torch::executor::
+                  promote_types<CTYPE_A, CTYPE_B>::type;
+              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+              ET_SWITCH_REAL_TYPES_AND(
                   Bool,
-                  common_type,
+                  out_type,
                   ctx,
                   "bitwise_and.Tensor_out",
-                  CTYPE_IN,
+                  CTYPE_OUT,
                   [&]() {
-                    ET_SWITCH_REAL_TYPES_AND(
-                        Bool,
-                        out_type,
-                        ctx,
-                        "bitwise_and.Tensor_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value =
-                                    bitwise_and(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
+                    internal::BitwiseOpInner<
+                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
+                        std::bit_and,
+                        CTYPE_A,
+                        CTYPE_B,
+                        CTYPE_IN,
+                        CTYPE_OUT>::run(a, b, out);
                   });
             });
       });
@@ -142,8 +115,8 @@ Tensor& bitwise_and_Scalar_out(
                       static_cast<CTYPE_IN>(val_a);
                   CTYPE_IN b_casted =
                       static_cast<CTYPE_IN>(val_b);
-                  CTYPE_IN value =
-                      bitwise_and(a_casted, b_casted);
+                  CTYPE_IN value = std::bit_and<CTYPE_IN>()(
+                      a_casted, b_casted);
 
                   return static_cast<CTYPE_OUT>(value);
                 },
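With the pattern factored into bitwise_op.h, the hand-written bitwise_and helper and its bool specialization are no longer needed: std::bit_and covers both cases, because for bool operands x & y and x && y agree on every input, and BitwiseOpInner accepts the functor as a template template parameter (hence std::bit_and is passed above without template arguments). A quick standalone check of that equivalence, not ExecuTorch code:

#include <cassert>
#include <functional>
#include <initializer_list>

int main() {
  // Integer case: plain bitwise AND.
  assert(std::bit_and<int>()(0b1100, 0b1010) == 0b1000);
  // Bool case: matches the deleted `return a && b;` specialization on
  // all four input combinations (no short-circuit side effects here).
  for (bool x : {false, true}) {
    for (bool y : {false, true}) {
      assert(std::bit_and<bool>()(x, y) == (x && y));
    }
  }
  return 0;
}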