@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
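The only header change: the new loops below rely on `BroadcastIndexesRange`, which yields one tuple of (output index, per-input index) per output element, so callers no longer delinearize and re-linearize indices by hand. That header isn't part of this diff; as a rough mental model only, here is a minimal, self-contained sketch of the stride-zero odometer such a range can be built on. Every name in it is a hypothetical stand-in, not the real API:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical stand-in for the real BroadcastIndexesRange: given an output
// shape and an input shape (already right-aligned and padded to the same
// rank), visit every output element together with the matching input
// element, treating size-1 input dimensions as "stuck" at index 0.
template <size_t Rank>
void for_each_broadcast_pair(
    const std::array<size_t, Rank>& out_shape,
    const std::array<size_t, Rank>& in_shape) {
  // Row-major strides for both tensors; a size-1 input dim gets stride 0,
  // which is what makes broadcasting fall out of plain strided arithmetic.
  std::array<size_t, Rank> out_stride{}, in_stride{};
  size_t os = 1, is = 1;
  for (size_t d = Rank; d-- > 0;) {
    out_stride[d] = os;
    in_stride[d] = (in_shape[d] == 1) ? 0 : is;
    os *= out_shape[d];
    is *= in_shape[d];
  }
  std::array<size_t, Rank> coord{};  // current multi-dimensional index
  size_t out_index = 0, in_index = 0;
  for (;;) {
    std::printf("out %zu -> in %zu\n", out_index, in_index);
    // Odometer-style increment: bump the last dim, carry on overflow.
    size_t d = Rank;
    while (d-- > 0) {
      ++coord[d];
      out_index += out_stride[d];
      in_index += in_stride[d];
      if (coord[d] < out_shape[d]) break;
      out_index -= coord[d] * out_stride[d];
      in_index -= coord[d] * in_stride[d];
      coord[d] = 0;
      if (d == 0) return;  // wrapped past the first dim: done
    }
  }
}

int main() {
  // out shape {2, 3} broadcast against in shape {1, 3} prints the pairs
  // (0,0) (1,1) (2,2) (3,0) (4,1) (5,2).
  for_each_broadcast_pair<2>({2, 3}, {1, 3});
}
```

The point of the trick is that each step costs a few additions and comparisons, instead of the per-element div/mod round trip the removed code performed.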
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
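For comparison, the removed loop body performed this round trip for every output element: delinearize the flat index `i` into output coordinates, then re-linearize those coordinates against each broadcasted input, pinning size-1 dimensions to index 0. A simplified, standalone sketch of that math (the real `delinearize_index` / `linearize_access_indexes` live in `broadcast_util.h`; these versions assume shapes already right-aligned to a common rank):

```cpp
#include <cassert>
#include <cstddef>

// Split a flat row-major index into per-dimension coordinates.
inline void delinearize(
    size_t linear, const size_t* shape, size_t dim, size_t* coords) {
  for (size_t d = dim; d-- > 0;) {
    coords[d] = linear % shape[d];
    linear /= shape[d];
  }
}

// Re-linearize output coordinates against a (possibly broadcast) input
// shape: a size-1 input dimension always contributes index 0.
inline size_t linearize_broadcast(
    const size_t* coords, const size_t* in_shape, size_t dim) {
  size_t linear = 0;
  for (size_t d = 0; d < dim; ++d) {
    const size_t c = (in_shape[d] == 1) ? 0 : coords[d];
    linear = linear * in_shape[d] + c;
  }
  return linear;
}

int main() {
  // out shape {2, 3} against broadcast input shape {1, 3}:
  // flat output index 4 -> coordinates {1, 1} -> input index 1.
  const size_t out_shape[2] = {2, 3};
  const size_t in_shape[2] = {1, 3};
  size_t coords[2];
  delinearize(4, out_shape, 2, coords);
  assert(linearize_broadcast(coords, in_shape, 2) == 1);
}
```

The div/mod chain here runs once per output element per broadcasted input, which is presumably what the incremental `BroadcastIndexesRange` iteration is meant to avoid.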
@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
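A minor note on the shape of the new loops: `for (const auto [out_index, a_index, b_index, c_index] : ...)` works as long as the range's iterator dereferences to something tuple-like; an `std::array<size_t, N + 1>` is enough. A tiny demonstration of just that consuming pattern, with a faked range standing in for the real one:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Fake stand-in for BroadcastIndexesRange<3>: each element carries the
  // output index followed by one index per input tensor.
  std::vector<std::array<size_t, 4>> fake_range = {
      {0, 0, 0, 0}, {1, 1, 0, 1}, {2, 2, 0, 2}};
  for (const auto [out_index, a_index, b_index, c_index] : fake_range) {
    std::printf("%zu %zu %zu %zu\n", out_index, a_index, b_index, c_index);
  }
}
```

Whatever the real iterator yields, the call sites stay identical between the 2- and 3-input cases, which is what keeps the two hunks above symmetrical.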