@@ -129,24 +129,23 @@ Tensor& scalar_comparison_op_with_regular_promotion_out(
129
129
130
130
ET_SWITCH_REAL_TYPES_AND (Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
131
131
ET_SWITCH_SCALAR_OBJ_TYPES (b_type, ctx, op_name, CTYPE_B, [&]() {
132
- ET_SWITCH_REAL_TYPES_AND (
133
- Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
134
- ET_SWITCH_REAL_TYPES_AND (
135
- Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
136
- CTYPE_B val_b = 0 ;
137
- utils::extract_scalar (b, &val_b);
138
- apply_unary_map_fn (
139
- [val_b](const CTYPE_A val_a) {
140
- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
141
- CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
142
- bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
143
- return static_cast <CTYPE_OUT>(value);
144
- },
145
- a.const_data_ptr <CTYPE_A>(),
146
- out.mutable_data_ptr <CTYPE_OUT>(),
147
- out.numel ());
148
- });
149
- });
132
+ using CTYPE_IN =
133
+ typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
134
+ ET_DCHECK (CppTypeToScalarType<CTYPE_IN>::value == common_type);
135
+ ET_SWITCH_REAL_TYPES_AND (Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
136
+ CTYPE_B val_b = 0 ;
137
+ utils::extract_scalar (b, &val_b);
138
+ apply_unary_map_fn (
139
+ [val_b](const CTYPE_A val_a) {
140
+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
141
+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(val_b);
142
+ bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
143
+ return static_cast <CTYPE_OUT>(value);
144
+ },
145
+ a.const_data_ptr <CTYPE_A>(),
146
+ out.mutable_data_ptr <CTYPE_OUT>(),
147
+ out.numel ());
148
+ });
150
149
});
151
150
});
152
151
0 commit comments