Skip to content

Commit 47d080e

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce max/min size & build time
Summary: Yet another smaller pair of ops.
Differential Revision: D56807402
1 parent 4eb266d commit 47d080e

File tree

2 files changed

+106
-29
lines changed

2 files changed

+106
-29
lines changed

kernels/portable/cpu/op_maximum.cpp

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,49 @@ const T& max(const T& a, const T& b) {
2020
return (b > a) ? b : a;
2121
}
2222

23+
template <
24+
bool can_cast,
25+
typename CTYPE_A,
26+
typename CTYPE_B,
27+
typename CTYPE_IN,
28+
typename CTYPE_OUT>
29+
struct MaximumInner;
30+
31+
template <
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct MaximumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
40+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
41+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
42+
CTYPE_IN value = max(a_casted, b_casted);
43+
44+
return static_cast<CTYPE_OUT>(value);
45+
},
46+
a,
47+
b,
48+
out);
49+
}
50+
};
51+
52+
struct ReportCanCastBug {
53+
static void run(const Tensor&, const Tensor&, Tensor&) {
54+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
55+
}
56+
};
57+
58+
template <
59+
typename CTYPE_A,
60+
typename CTYPE_B,
61+
typename CTYPE_IN,
62+
typename CTYPE_OUT>
63+
struct MaximumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
64+
: public ReportCanCastBug {};
65+
2366
} // namespace
2467

2568
Tensor& maximum_out(
@@ -44,20 +87,16 @@ Tensor& maximum_out(
4487

4588
ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() {
4689
ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() {
47-
ET_SWITCH_REALB_TYPES(common_type, ctx, "maximum.out", CTYPE_IN, [&]() {
48-
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
49-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
50-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
51-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
52-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
53-
CTYPE_IN value = max(a_casted, b_casted);
54-
55-
return static_cast<CTYPE_OUT>(value);
56-
},
57-
a,
58-
b,
59-
out);
60-
});
90+
using CTYPE_IN = typename torch::executor::
91+
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
92+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
93+
ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() {
94+
MaximumInner<
95+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
96+
CTYPE_A,
97+
CTYPE_B,
98+
CTYPE_IN,
99+
CTYPE_OUT>::run(a, b, out);
61100
});
62101
});
63102
});

kernels/portable/cpu/op_minimum.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,49 @@ const T& min(const T& a, const T& b) {
2020
return (b < a) ? b : a;
2121
}
2222

23+
template <
24+
bool can_cast,
25+
typename CTYPE_A,
26+
typename CTYPE_B,
27+
typename CTYPE_IN,
28+
typename CTYPE_OUT>
29+
struct MinimumInner;
30+
31+
template <
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct MinimumInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
38+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
40+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
41+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
42+
CTYPE_IN value = min(a_casted, b_casted);
43+
44+
return static_cast<CTYPE_OUT>(value);
45+
},
46+
a,
47+
b,
48+
out);
49+
}
50+
};
51+
52+
struct ReportCanCastBug {
53+
static void run(const Tensor&, const Tensor&, Tensor&) {
54+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
55+
}
56+
};
57+
58+
template <
59+
typename CTYPE_A,
60+
typename CTYPE_B,
61+
typename CTYPE_IN,
62+
typename CTYPE_OUT>
63+
struct MinimumInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
64+
: public ReportCanCastBug {};
65+
2366
} // namespace
2467

2568
Tensor& minimum_out(
@@ -44,22 +87,17 @@ Tensor& minimum_out(
4487

4588
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "minimum.out", CTYPE_A, [&]() {
4689
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "minimum.out", CTYPE_B, [&]() {
90+
using CTYPE_IN =
91+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
92+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
4793
ET_SWITCH_REAL_TYPES_AND(
48-
Bool, common_type, ctx, "minimum.out", CTYPE_IN, [&]() {
49-
ET_SWITCH_REAL_TYPES_AND(
50-
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
51-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
53-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
54-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
55-
CTYPE_IN value = min(a_casted, b_casted);
56-
57-
return static_cast<CTYPE_OUT>(value);
58-
},
59-
a,
60-
b,
61-
out);
62-
});
94+
Bool, out_type, ctx, "minimum.out", CTYPE_OUT, [&]() {
95+
MinimumInner<
96+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
97+
CTYPE_A,
98+
CTYPE_B,
99+
CTYPE_IN,
100+
CTYPE_OUT>::run(a, b, out);
63101
});
64102
});
65103
});

0 commit comments

Comments
 (0)