Skip to content

Deploy BroadcastIndexesRange #8865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 17 commits on Mar 6, 2025
56 changes: 20 additions & 36 deletions kernels/portable/cpu/util/broadcast_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

Expand Down Expand Up @@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn(
const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
}
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index] :
BroadcastIndexesRange<2>(out, a, b)) {
data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;

data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
}
}
}

Expand Down Expand Up @@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn(
const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
}
if (c_is_broadcasted) {
c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
}
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index, c_index] :
BroadcastIndexesRange<3>(out, a, b, c)) {
data_out[out_index] =
compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
}

data_out[i] = compute_fun(
data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
}
}

Expand Down
81 changes: 38 additions & 43 deletions kernels/portable/cpu/util/elementwise_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include <c10/util/irange.h>
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
Expand Down Expand Up @@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
}
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index] :
BroadcastIndexesRange<2>(out, a, b)) {
auto result = compute_fun(
load_a_to_common(&data_a[a_index * a_element_size]),
load_b_to_common(&data_b[b_index * b_element_size]));
store_common_to_out(result, &data_out[out_index * out_element_size]);
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}

Expand Down Expand Up @@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());

auto out_numel = out.numel();
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

if (any_is_broadcasted) {
size_t out_indexes[kTensorDimensionLimit];
delinearize_index(i, out, out_indexes, kTensorDimensionLimit);

if (a_is_broadcasted) {
a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
}
if (b_is_broadcasted) {
b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
}
if (c_is_broadcasted) {
c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
}
if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index, c_index] :
BroadcastIndexesRange<3>(out, a, b, c)) {
auto result = compute_fun(
load_a_to_common(&data_a[a_index * a_element_size]),
load_b_to_common(&data_b[b_index * b_element_size]),
load_c_to_common(&data_c[c_index * c_element_size]));
store_common_to_out(result, &data_out[out_index * out_element_size]);
}
} else {
for (const auto i : c10::irange(out_numel)) {
size_t a_linear_index = i;
size_t b_linear_index = i;
size_t c_linear_index = i;

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}

auto result = compute_fun(
load_a_to_common(&data_a[a_linear_index * a_element_size]),
load_b_to_common(&data_b[b_linear_index * b_element_size]),
load_c_to_common(&data_c[c_linear_index * c_element_size]));
store_common_to_out(result, &data_out[i * out_element_size]);
}
}

Expand Down
3 changes: 3 additions & 0 deletions kernels/portable/cpu/util/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def define_common_targets():
exported_headers = [
"broadcast_util.h",
],
exported_deps = [
":broadcast_indexes_range",
],
deps = [
":repeat_util",
"//executorch/runtime/kernel:kernel_includes",
Expand Down
Loading