@@ -19,6 +19,60 @@ namespace native {
 
 using Tensor = exec_aten::Tensor;
 
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
+          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_OUT>(0);
+            }
+          }
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = std::fmod(a_casted, b_casted);
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct FmodInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug {};
+
+} // namespace
+
 Tensor& fmod_Tensor_out(
     RuntimeContext& ctx,
     const Tensor& a,
@@ -44,35 +98,18 @@ Tensor& fmod_Tensor_out(
       Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() {
         ET_SWITCH_REAL_TYPES_AND(
             Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() {
+              using CTYPE_IN = typename torch::executor::
+                  promote_types<CTYPE_A, CTYPE_B>::type;
+              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
               ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "fmod.Tensor_out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
-                          apply_binary_elementwise_fn<
-                              CTYPE_A,
-                              CTYPE_B,
-                              CTYPE_OUT>(
-                              [common_type, &div_by_zero_error](
-                                  const CTYPE_A val_a, const CTYPE_B val_b) {
-                                if (isIntegralType(
-                                        common_type, /*includeBool=*/true)) {
-                                  if (val_b == 0) {
-                                    div_by_zero_error = true;
-                                    return static_cast<CTYPE_OUT>(0);
-                                  }
-                                }
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = std::fmod(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a,
-                              b,
-                              out);
-                        });
+                  out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
+                    FmodInner<
+                        !std::is_same<CTYPE_IN, bool>::value &&
+                            can_cast<CTYPE_IN, CTYPE_OUT>::value,
+                        CTYPE_A,
+                        CTYPE_B,
+                        CTYPE_IN,
+                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
                   });
             });
       });
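For reference, a minimal standalone sketch of the dispatch pattern this diff introduces: the old runtime `isIntegralType(common_type, ...)` test inside a fourth nested `ET_SWITCH` is replaced by a `bool` non-type template parameter that routes each `CTYPE_IN`/`CTYPE_OUT` combination either to the real implementation or to a bug-reporting stub, so invalid cast combinations never instantiate the fmod kernel. All names below are illustrative stand-ins (`FmodExample` for `FmodInner`, plain `assert` for `ET_DCHECK_MSG`), not the ExecuTorch API:

// fmod_dispatch_sketch.cpp -- illustrative only, not the ExecuTorch code.
#include <cassert>
#include <cmath>
#include <cstdio>

// Primary template is declared but never defined, so any combination the
// specializations below do not cover is a compile-time error.
template <bool can_cast, typename CTYPE_IN, typename CTYPE_OUT>
struct FmodExample;

// can_cast == true: the CTYPE_IN -> CTYPE_OUT conversion was validated at
// compile time, so the real computation is instantiated.
template <typename CTYPE_IN, typename CTYPE_OUT>
struct FmodExample<true, CTYPE_IN, CTYPE_OUT> {
  static CTYPE_OUT run(CTYPE_IN a, CTYPE_IN b) {
    return static_cast<CTYPE_OUT>(std::fmod(a, b));
  }
};

// can_cast == false: this stub keeps every switch branch compilable, but
// reaching it at runtime means the caller's cast check was wrong.
template <typename CTYPE_IN, typename CTYPE_OUT>
struct FmodExample<false, CTYPE_IN, CTYPE_OUT> {
  static CTYPE_OUT run(CTYPE_IN, CTYPE_IN) {
    assert(false && "BUG: cast validity should have been checked above");
    return CTYPE_OUT{};
  }
};

int main() {
  // The selector is a compile-time constant, mirroring the diff's
  // !std::is_same<CTYPE_IN, bool>::value && can_cast<...>::value condition,
  // so only one specialization is instantiated per type combination.
  float r = FmodExample<true, double, float>::run(7.5, 2.0);
  std::printf("%f\n", r); // prints 1.500000
}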