Skip to content

Commit fe97d11

Browse files
swolchok authored and facebook-github-bot committed
Use compile-time promotion to reduce fmod size & build time (#3456)
Summary: Almost done with Tensor ops that can benefit from compile-time promotion! Reviewed By: manuelcandales Differential Revision: D56835200
1 parent 47fb0cf commit fe97d11

File tree

2 files changed

+78
-28
lines changed

2 files changed

+78
-28
lines changed

kernels/portable/cpu/op_fmod.cpp

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,60 @@ namespace native {
1919

2020
using Tensor = exec_aten::Tensor;
2121

22+
namespace {
23+
template <
24+
bool can_cast,
25+
typename CTYPE_A,
26+
typename CTYPE_B,
27+
typename CTYPE_IN,
28+
typename CTYPE_OUT>
29+
struct FmodInner;
30+
31+
template <
32+
typename CTYPE_A,
33+
typename CTYPE_B,
34+
typename CTYPE_IN,
35+
typename CTYPE_OUT>
36+
struct FmodInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
37+
static void
38+
run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
39+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
40+
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
41+
[&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
42+
if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
43+
if (val_b == 0) {
44+
div_by_zero_error = true;
45+
return static_cast<CTYPE_OUT>(0);
46+
}
47+
}
48+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
49+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
50+
CTYPE_IN value = std::fmod(a_casted, b_casted);
51+
52+
return static_cast<CTYPE_OUT>(value);
53+
},
54+
a,
55+
b,
56+
out);
57+
}
58+
};
59+
60+
struct ReportCanCastBug {
61+
static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
62+
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
63+
}
64+
};
65+
66+
template <
67+
typename CTYPE_A,
68+
typename CTYPE_B,
69+
typename CTYPE_IN,
70+
typename CTYPE_OUT>
71+
struct FmodInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
72+
: public ReportCanCastBug {};
73+
74+
} // namespace
75+
2276
Tensor& fmod_Tensor_out(
2377
RuntimeContext& ctx,
2478
const Tensor& a,
@@ -44,35 +98,18 @@ Tensor& fmod_Tensor_out(
4498
Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() {
4599
ET_SWITCH_REAL_TYPES_AND(
46100
Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() {
101+
using CTYPE_IN = typename torch::executor::
102+
promote_types<CTYPE_A, CTYPE_B>::type;
103+
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
47104
ET_SWITCH_REAL_TYPES(
48-
common_type, ctx, "fmod.Tensor_out", CTYPE_IN, [&]() {
49-
ET_SWITCH_REAL_TYPES(
50-
out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
51-
apply_binary_elementwise_fn<
52-
CTYPE_A,
53-
CTYPE_B,
54-
CTYPE_OUT>(
55-
[common_type, &div_by_zero_error](
56-
const CTYPE_A val_a, const CTYPE_B val_b) {
57-
if (isIntegralType(
58-
common_type, /*includeBool=*/true)) {
59-
if (val_b == 0) {
60-
div_by_zero_error = true;
61-
return static_cast<CTYPE_OUT>(0);
62-
}
63-
}
64-
CTYPE_IN a_casted =
65-
static_cast<CTYPE_IN>(val_a);
66-
CTYPE_IN b_casted =
67-
static_cast<CTYPE_IN>(val_b);
68-
CTYPE_IN value = std::fmod(a_casted, b_casted);
69-
70-
return static_cast<CTYPE_OUT>(value);
71-
},
72-
a,
73-
b,
74-
out);
75-
});
105+
out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
106+
FmodInner<
107+
!std::is_same<CTYPE_IN, bool>::value &&
108+
can_cast<CTYPE_IN, CTYPE_OUT>::value,
109+
CTYPE_A,
110+
CTYPE_B,
111+
CTYPE_IN,
112+
CTYPE_OUT>::run(a, b, out, div_by_zero_error);
76113
});
77114
});
78115
});

kernels/test/op_fmod_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,16 @@ class OpFmodTest : public OperatorTest {
3232
return torch::executor::aten::fmod_outf(context_, self, other, out);
3333
}
3434
};
35+
36+
// Smoke test for fmod.Tensor_out with mixed input dtypes: Long % Int is
// promoted to Long internally and written to a Long out tensor.
TEST_F(OpFmodTest, SmokeTest) {
  // Fix: the original declared `TensorFactory<ScalarType::Long> tfDouble`,
  // a misleading copy-paste name duplicating tfLong (and fed it a 2.0
  // float literal). Same dtypes flow through the op; only names/literals
  // are corrected.
  TensorFactory<ScalarType::Long> tfLong;
  TensorFactory<ScalarType::Int> tfInt;

  // fmod(46, 4) == 2 elementwise.
  Tensor self = tfLong.full({2, 2}, 46);
  Tensor other = tfInt.full({2, 2}, 4);
  Tensor out = tfLong.zeros({2, 2});
  Tensor out_expected = tfLong.full({2, 2}, 2);
  op_fmod_tensor_out(self, other, out);
  EXPECT_TENSOR_CLOSE(out, out_expected);
}

0 commit comments

Comments (0)