
Commit e4ab6c2

swolchok authored and Zonglin Peng committed
Deploy BroadcastIndexesRange (#8865)
1 parent f2ed70e commit e4ab6c2

File tree

3 files changed: +61 −79 lines

kernels/portable/cpu/util/broadcast_util.h

+20 −36

@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 
@@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn(
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
     }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
 
-    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+    }
   }
 }
 
@@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn(
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      data_out[out_index] =
+          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
     }
-
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
   }
 }
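For context: the deleted loop bodies above computed each input's linear index by hand, first delinearizing the output index into per-dimension coordinates with delinearize_index and then re-linearizing them against the (possibly size-1) input dimensions with linearize_access_indexes. BroadcastIndexesRange packages exactly that mapping. The sketch below is a self-contained illustration of the arithmetic, assuming plain row-major std::vector shapes and a hypothetical broadcast_index helper; it is not the ExecuTorch Tensor API and does not mirror the real function signatures.

// Standalone illustration (not the ExecuTorch API): the mapping that the
// removed code computed with delinearize_index + linearize_access_indexes,
// and that BroadcastIndexesRange now produces for every output element.
// Shapes are plain row-major size vectors; broadcast_index is a
// hypothetical helper invented for this sketch.
#include <cstddef>
#include <cstdio>
#include <vector>

// Map a linear index into `out_sizes` to the corresponding linear index
// into `in_sizes`, where every input dimension either matches the output
// dimension or is 1 (broadcast).
size_t broadcast_index(
    size_t out_linear,
    const std::vector<size_t>& out_sizes,
    const std::vector<size_t>& in_sizes) {
  const size_t rank = out_sizes.size();
  // Delinearize: recover per-dimension coordinates, last dimension fastest.
  std::vector<size_t> coord(rank);
  for (size_t d = rank; d-- > 0;) {
    coord[d] = out_linear % out_sizes[d];
    out_linear /= out_sizes[d];
  }
  // Re-linearize against the input's own sizes, pinning broadcast
  // (size-1) dimensions to coordinate 0.
  size_t in_linear = 0;
  for (size_t d = 0; d < rank; ++d) {
    in_linear = in_linear * in_sizes[d] + (in_sizes[d] == 1 ? 0 : coord[d]);
  }
  return in_linear;
}

int main() {
  // out: [2, 3], a: [2, 3] (no broadcast), b: [1, 3] (broadcast along dim 0).
  const std::vector<size_t> out_sizes{2, 3}, a_sizes{2, 3}, b_sizes{1, 3};
  for (size_t i = 0; i < 6; ++i) {
    std::printf(
        "out %zu -> a %zu, b %zu\n",
        i,
        broadcast_index(i, out_sizes, a_sizes),
        broadcast_index(i, out_sizes, b_sizes));
  }
  return 0;
}

For b this prints indexes 0, 1, 2, 0, 1, 2: the single row of b is reused for both output rows, which is what the a_is_broadcasted / b_is_broadcasted branches used to arrange by hand.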

kernels/portable/cpu/util/elementwise_util.h

+38 −43

@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
 
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
 
@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
     }
-
-    auto result = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
-    store_common_to_out(result, &data_out[i * out_element_size]);
   }
 }
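The new loops rely on BroadcastIndexesRange being a range whose iterator yields one index tuple per output element, so structured bindings such as [out_index, a_index, b_index] work directly in a range-for. Below is a rough, assumption-heavy sketch of that style of API: ToyBroadcastIndexesRange is invented for this note, works on plain shape vectors, handles exactly two inputs, and recomputes each index from scratch. The real BroadcastIndexesRange in broadcast_indexes_range.h is templated on the input count and iterates executorch Tensors.

// Toy sketch of a BroadcastIndexesRange-style range (hypothetical type,
// not the ExecuTorch implementation).
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

class ToyBroadcastIndexesRange {
 public:
  ToyBroadcastIndexesRange(
      std::vector<size_t> out_sizes,
      std::vector<size_t> a_sizes,
      std::vector<size_t> b_sizes)
      : out_sizes_(std::move(out_sizes)),
        a_sizes_(std::move(a_sizes)),
        b_sizes_(std::move(b_sizes)) {
    for (size_t s : out_sizes_) {
      numel_ *= s;
    }
  }

  class Iterator {
   public:
    Iterator(const ToyBroadcastIndexesRange* range, size_t i)
        : range_(range), i_(i) {}
    // Yields {out_index, a_index, b_index}, so callers can write
    // for (const auto [out_index, a_index, b_index] : range) { ... }.
    std::array<size_t, 3> operator*() const {
      return {i_, range_->map(i_, range_->a_sizes_),
              range_->map(i_, range_->b_sizes_)};
    }
    Iterator& operator++() {
      ++i_;
      return *this;
    }
    bool operator!=(const Iterator& other) const {
      return i_ != other.i_;
    }

   private:
    const ToyBroadcastIndexesRange* range_;
    size_t i_;
  };

  Iterator begin() const { return Iterator(this, 0); }
  Iterator end() const { return Iterator(this, numel_); }

 private:
  // Delinearize the output index, then re-linearize against `in_sizes`,
  // treating broadcast (size-1) dimensions as coordinate 0.
  size_t map(size_t out_linear, const std::vector<size_t>& in_sizes) const {
    const size_t rank = out_sizes_.size();
    std::vector<size_t> coord(rank);
    for (size_t d = rank; d-- > 0;) {
      coord[d] = out_linear % out_sizes_[d];
      out_linear /= out_sizes_[d];
    }
    size_t in_linear = 0;
    for (size_t d = 0; d < rank; ++d) {
      in_linear = in_linear * in_sizes[d] + (in_sizes[d] == 1 ? 0 : coord[d]);
    }
    return in_linear;
  }

  std::vector<size_t> out_sizes_, a_sizes_, b_sizes_;
  size_t numel_ = 1;
};

int main() {
  // out: [2, 3], a: [2, 3], b: [1, 3] (b broadcast along dim 0).
  ToyBroadcastIndexesRange range({2, 3}, {2, 3}, {1, 3});
  for (const auto [out_index, a_index, b_index] : range) {
    std::printf("out %zu <- a %zu, b %zu\n", out_index, a_index, b_index);
  }
  return 0;
}

The point of the pattern is that a single range-for replaces the per-element delinearize/linearize branching, while the load/compute/store body of apply_bitensor_elementwise_fn and apply_tritensor_elementwise_fn stays the same.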

kernels/portable/cpu/util/targets.bzl

+3 −0

@@ -70,6 +70,9 @@ def define_common_targets():
         exported_headers = [
             "broadcast_util.h",
         ],
+        exported_deps = [
+            ":broadcast_indexes_range",
+        ],
         deps = [
             ":repeat_util",
             "//executorch/runtime/kernel:kernel_includes",

0 commit comments
