@@ -20,6 +20,49 @@ const T& min(const T& a, const T& b) {
20
20
return (b < a) ? b : a;
21
21
}
22
22
23
+ template <
24
+ bool can_cast,
25
+ typename CTYPE_A,
26
+ typename CTYPE_B,
27
+ typename CTYPE_IN,
28
+ typename CTYPE_OUT>
29
+ struct MinimumInner ;
30
+
31
+ template <
32
+ typename CTYPE_A,
33
+ typename CTYPE_B,
34
+ typename CTYPE_IN,
35
+ typename CTYPE_OUT>
36
+ struct MinimumInner <true , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37
+ static void run (const Tensor& a, const Tensor& b, Tensor& out) {
38
+ apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
39
+ [](const CTYPE_A val_a, const CTYPE_B val_b) {
40
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
41
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
42
+ CTYPE_IN value = min (a_casted, b_casted);
43
+
44
+ return static_cast <CTYPE_OUT>(value);
45
+ },
46
+ a,
47
+ b,
48
+ out);
49
+ }
50
+ };
51
+
52
+ struct ReportCanCastBug {
53
+ static void run (const Tensor&, const Tensor&, Tensor&) {
54
+ ET_DCHECK_MSG (false , " BUG: canCast should have been checked above" );
55
+ }
56
+ };
57
+
58
+ template <
59
+ typename CTYPE_A,
60
+ typename CTYPE_B,
61
+ typename CTYPE_IN,
62
+ typename CTYPE_OUT>
63
+ struct MinimumInner <false , CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
64
+ : public ReportCanCastBug {};
65
+
23
66
} // namespace
24
67
25
68
Tensor& minimum_out (
@@ -44,22 +87,17 @@ Tensor& minimum_out(
44
87
45
88
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, " minimum.out" , CTYPE_A, [&]() {
46
89
ET_SWITCH_REAL_TYPES_AND (Bool, b_type, ctx, " minimum.out" , CTYPE_B, [&]() {
90
+ using CTYPE_IN =
91
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
92
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
47
93
ET_SWITCH_REAL_TYPES_AND (
48
- Bool, common_type, ctx, " minimum.out" , CTYPE_IN, [&]() {
49
- ET_SWITCH_REAL_TYPES_AND (
50
- Bool, out_type, ctx, " minimum.out" , CTYPE_OUT, [&]() {
51
- apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
52
- [](const CTYPE_A val_a, const CTYPE_B val_b) {
53
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
54
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
55
- CTYPE_IN value = min (a_casted, b_casted);
56
-
57
- return static_cast <CTYPE_OUT>(value);
58
- },
59
- a,
60
- b,
61
- out);
62
- });
94
+ Bool, out_type, ctx, " minimum.out" , CTYPE_OUT, [&]() {
95
+ MinimumInner<
96
+ can_cast<CTYPE_IN, CTYPE_OUT>::value,
97
+ CTYPE_A,
98
+ CTYPE_B,
99
+ CTYPE_IN,
100
+ CTYPE_OUT>::run (a, b, out);
63
101
});
64
102
});
65
103
});
0 commit comments