@@ -58,17 +58,17 @@ Tensor& div_out(
5858 static constexpr const char op_name[] = " div.out" ;
5959
6060 ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61- utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
62- [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
63- return val_a / val_b;
64- },
61+ utils::apply_bitensor_elementwise_fn<
62+ CTYPE_COMPUTE,
63+ op_name,
64+ utils::SupportedTensorDtypes::FLOATHBF16>(
65+ [](const auto val_a, const auto val_b) { return val_a / val_b; },
6566 ctx,
6667 a,
6768 utils::SupportedTensorDtypes::REALHBBF16,
6869 b,
6970 utils::SupportedTensorDtypes::REALHBBF16,
70- out,
71- utils::SupportedTensorDtypes::FLOATHBF16);
71+ out);
7272 });
7373
7474 return out;
@@ -122,9 +122,13 @@ Tensor& div_out_mode(
122122 bool div_by_zero_error = false ;
123123
124124 ET_SWITCH_REAL_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
125- utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
125+ utils::apply_bitensor_elementwise_fn<
126+ CTYPE_COMPUTE,
127+ op_name,
128+ utils::SupportedTensorDtypes::REALHBF16>(
126129 [mode_is_trunc, &div_by_zero_error](
127130 const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
131+ // TODO: rewrite this to be vectorization-capable.
128132 if (is_integral_type<CTYPE_COMPUTE, /* includeBool=*/ true >::value) {
129133 if (val_b == 0 ) {
130134 div_by_zero_error = true ;
@@ -146,8 +150,7 @@ Tensor& div_out_mode(
146150 utils::SupportedTensorDtypes::REALHBBF16,
147151 b,
148152 utils::SupportedTensorDtypes::REALHBBF16,
149- out,
150- utils::SupportedTensorDtypes::REALHBF16);
153+ out);
151154 });
152155
153156 ET_KERNEL_CHECK_MSG (
@@ -188,13 +191,15 @@ Tensor& div_scalar_out(
188191
189192 ET_SWITCH_FLOAT_TYPES (compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
190193 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
191- utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
192- [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
194+ utils::apply_unitensor_elementwise_fn<
195+ CTYPE_COMPUTE,
196+ op_name,
197+ utils::SupportedTensorDtypes::SAME_AS_COMMON>(
198+ [val_b](const auto val_a) { return val_a / val_b; },
193199 ctx,
194200 a,
195201 utils::SupportedTensorDtypes::REALHBBF16,
196- out,
197- utils::SupportedTensorDtypes::SAME_AS_COMMON);
202+ out);
198203 });
199204
200205 return out;
0 commit comments