Skip to content

Commit 895efcb

Browse files
authored
Support alpha in scalar add/sub cases
Differential Revision: D71949902 Pull Request resolved: #9703
1 parent 879b94f commit 895efcb

File tree

3 files changed

+8
-22
lines changed

3 files changed

+8
-22
lines changed

backends/cadence/hifi/operators/op_add.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,15 @@ Tensor& add_out(
143143

144144
if ((a_dim == 0) && float_types) {
145145
for (int i = 0; i < b.numel(); i++)
146-
out.mutable_data_ptr<float>()[i] =
147-
a.const_data_ptr<float>()[0] + b.const_data_ptr<float>()[i];
146+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[0] +
147+
alpha_val * b.const_data_ptr<float>()[i];
148148
return out;
149149
}
150150
if ((b_dim == 0) && float_types) {
151+
// Precompute the value of b * alpha since it's a constant.
152+
const float val_b = alpha_val * b.const_data_ptr<float>()[0];
151153
for (int i = 0; i < a.numel(); i++)
152-
out.mutable_data_ptr<float>()[i] =
153-
a.const_data_ptr<float>()[i] + b.const_data_ptr<float>()[0];
154+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] + val_b;
154155
return out;
155156
}
156157

backends/cadence/hifi/operators/op_div.cpp

-16
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,6 @@ Tensor& div_out_mode(
214214
if ((a_type != ScalarType::Float) || (b_type != ScalarType::Float))
215215
optimized = 0;
216216

217-
bool float_types =
218-
(a_type == ScalarType::Float) && (b_type == ScalarType::Float);
219-
220-
if ((a_dim == 0) && float_types) {
221-
for (int i = 0; i < b.numel(); i++)
222-
out.mutable_data_ptr<float>()[i] =
223-
a.const_data_ptr<float>()[0] / b.const_data_ptr<float>()[i];
224-
return out;
225-
}
226-
if ((b_dim == 0) && float_types) {
227-
for (int i = 0; i < a.numel(); i++)
228-
out.mutable_data_ptr<float>()[i] =
229-
a.const_data_ptr<float>()[i] / b.const_data_ptr<float>()[0];
230-
return out;
231-
}
232-
233217
if ((broadcast == 1) && (max_dim > kNnlibMaxDim))
234218
optimized = 0;
235219
int mode_val = -1;

backends/cadence/hifi/operators/op_sub.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ Tensor& sub_out(
143143
return out;
144144
}
145145
if ((b_dim == 0) && float_types) {
146+
// Precompute the value of b * alpha since it's a constant.
147+
const float val_b = alpha_val * b.const_data_ptr<float>()[0];
146148
for (int i = 0; i < a.numel(); i++)
147-
out.mutable_data_ptr<float>()[i] =
148-
a.const_data_ptr<float>()[i] - b.const_data_ptr<float>()[0];
149+
out.mutable_data_ptr<float>()[i] = a.const_data_ptr<float>()[i] - val_b;
149150
return out;
150151
}
151152

0 commit comments

Comments
 (0)