Skip to content

Commit 4eb266d

Browse files
swolchokfacebook-github-bot
authored andcommitted
Use compile-time promotion to reduce eq/ne scalar op size & build time
Summary: These two scalar ops use promoteTypes, so we can use compile-time promotion right away. Differential Revision: D56744985
1 parent 94ebd30 commit 4eb266d

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

kernels/portable/cpu/pattern/comparison_op.h

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,24 +129,23 @@ Tensor& scalar_comparison_op_with_regular_promotion_out(
129129

130130
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, op_name, CTYPE_A, [&]() {
131131
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
132-
ET_SWITCH_REAL_TYPES_AND(
133-
Bool, common_type, ctx, op_name, CTYPE_IN, [&]() {
134-
ET_SWITCH_REAL_TYPES_AND(
135-
Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
136-
CTYPE_B val_b = 0;
137-
utils::extract_scalar(b, &val_b);
138-
apply_unary_map_fn(
139-
[val_b](const CTYPE_A val_a) {
140-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
141-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
142-
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
143-
return static_cast<CTYPE_OUT>(value);
144-
},
145-
a.const_data_ptr<CTYPE_A>(),
146-
out.mutable_data_ptr<CTYPE_OUT>(),
147-
out.numel());
148-
});
149-
});
132+
using CTYPE_IN =
133+
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
134+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
135+
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, op_name, CTYPE_OUT, [&]() {
136+
CTYPE_B val_b = 0;
137+
utils::extract_scalar(b, &val_b);
138+
apply_unary_map_fn(
139+
[val_b](const CTYPE_A val_a) {
140+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
141+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
142+
bool value = OpFunc<CTYPE_IN>()(a_casted, b_casted);
143+
return static_cast<CTYPE_OUT>(value);
144+
},
145+
a.const_data_ptr<CTYPE_A>(),
146+
out.mutable_data_ptr<CTYPE_OUT>(),
147+
out.numel());
148+
});
150149
});
151150
});
152151

0 commit comments

Comments
 (0)