@@ -58,17 +58,17 @@ Tensor& div_out(
   static constexpr const char op_name[] = "div.out";
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto val_a, const auto val_b) { return val_a / val_b; },
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a / val_b;
+        },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;
@@ -122,13 +122,9 @@ Tensor& div_out_mode(
   bool div_by_zero_error = false;
 
   ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::REALHBF16>(
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
         [mode_is_trunc, &div_by_zero_error](
             const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          // TODO: rewrite this to be vectorization-capable.
           if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
             if (val_b == 0) {
               div_by_zero_error = true;
@@ -150,7 +146,8 @@ Tensor& div_out_mode(
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   ET_KERNEL_CHECK_MSG(
@@ -191,15 +188,13 @@ Tensor& div_scalar_out(
 
   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON>(
-        [val_b](const auto val_a) { return val_a / val_b; },
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
         ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
-        out);
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;