Add BroadcastIndexesRange tests with dims of size 1 in output #8964

Merged: 31 commits, Mar 6, 2025
Changes from all commits
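For context: BroadcastIndexesRange walks the output tensor's linear indexes and, for each one, yields the matching linear index into every (possibly broadcast) input, so kernels can run a single loop without re-deriving coordinates per element. The standalone C++ sketch below mirrors that index mapping for an output that itself has a dim of size 1, the case this PR adds tests for. It assumes inputs are already padded to the output's rank, and broadcast_index is an illustrative helper, not the executorch API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Map a linear output index to a linear index into a (possibly broadcast)
// input by delinearizing against the output sizes and clamping any
// size-1 input dim to coordinate 0.
size_t broadcast_index(
    size_t out_linear,
    const std::vector<size_t>& out_sizes,
    const std::vector<size_t>& in_sizes) {
  size_t in_linear = 0;
  size_t in_stride = 1;
  size_t rem = out_linear;
  // Walk dims from innermost to outermost.
  for (int d = static_cast<int>(out_sizes.size()) - 1; d >= 0; --d) {
    const size_t coord = rem % out_sizes[d];
    rem /= out_sizes[d];
    in_linear += (in_sizes[d] == 1 ? 0 : coord) * in_stride;
    in_stride *= in_sizes[d];
  }
  return in_linear;
}

int main() {
  // Output shape {2, 1, 3}: the middle dim of size 1 is present in the
  // output itself, so every input necessarily has size 1 there too.
  const std::vector<size_t> out_sizes = {2, 1, 3};
  const std::vector<size_t> in_sizes = {2, 1, 1}; // broadcast along last dim
  for (size_t i = 0; i < 6; ++i) {
    // Prints 0->0 1->0 2->0 3->1 4->1 5->1.
    std::printf("%zu->%zu\n", i, broadcast_index(i, out_sizes, in_sizes));
  }
  return 0;
}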
97 changes: 97 additions & 0 deletions kernels/optimized/cpu/op_where.cpp
@@ -0,0 +1,97 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

Tensor& opt_where_out(
    KernelRuntimeContext& ctx,
    const Tensor& cond,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);

  // Resize
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, cond, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "where.self_out";

  // Fast path: all tensors already share the output dtype, so dispatch once
  // and select directly on raw typed buffers.
  if (a.scalar_type() == b.scalar_type() &&
      a.scalar_type() == out.scalar_type() && a.scalar_type() == compute_type &&
      // Using a Byte tensor for cond has been deprecated for a long time.
      cond.scalar_type() == ScalarType::Bool) {
    auto out_numel = out.numel();
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
      const bool any_is_broadcasted =
          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
      const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
      const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
      const bool* const data_cond = cond.const_data_ptr<bool>();
      CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
      if (any_is_broadcasted) {
        for (const auto [out_index, a_index, b_index, cond_index] :
             BroadcastIndexesRange<3>(out, a, b, cond)) {
          data_out[out_index] =
              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
        }
      } else {
        for (const auto i : c10::irange(out_numel)) {
          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
        }
      }
    });
  } else {
    // Fall back for mixed dtype to keep code size and compile time
    // reasonable.
    ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
      utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
          [](const CTYPE_COMPUTE val_a,
             const CTYPE_COMPUTE val_b,
             const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
          ctx,
          a,
          utils::SupportedTensorDtypes::REALHBBF16,
          b,
          utils::SupportedTensorDtypes::REALHBBF16,
          cond,
          utils::SupportedTensorDtypes::BOOL_OR_BYTE,
          out,
          utils::SupportedTensorDtypes::SAME_AS_COMMON);
    });
  }

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
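A note on the design above: the first branch keeps the common all-one-dtype case on a single ET_SWITCH dispatch with raw typed pointers, while mixed-dtype calls take the generic (slower, but much smaller to instantiate) elementwise utility. For readers new to the operator, this plain C++ sketch shows the semantics the kernel computes when cond is broadcast with a size-1 dim; it is illustrative only, not the executorch API.

#include <cstdio>

int main() {
  // where(cond, a, b) with cond of shape {2, 1} broadcast against
  // a and b of shape {2, 3}; out has shape {2, 3}.
  const bool cond[2][1] = {{true}, {false}};
  const float a[2][3] = {{1, 2, 3}, {4, 5, 6}};
  const float b[2][3] = {{-1, -2, -3}, {-4, -5, -6}};
  float out[2][3];
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      // cond's size-1 column dim is broadcast: index 0 regardless of j.
      out[i][j] = cond[i][0] ? a[i][j] : b[i][j];
    }
  }
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      std::printf("%g ", out[i][j]); // prints: 1 2 3 -4 -5 -6
    }
  }
  return 0;
}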
6 changes: 6 additions & 0 deletions kernels/optimized/cpu/targets.bzl
@@ -95,6 +95,12 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:broadcast_util",
         ],
     ),
+    op_target(
+        name = "op_where",
+        deps = [
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+        ],
+    ),
 )


Expand Down
5 changes: 5 additions & 0 deletions kernels/optimized/optimized.yaml
@@ -101,3 +101,8 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::opt_sub_scalar_out
+
+- op: where.self_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::opt_where_out
56 changes: 20 additions & 36 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>

@@ -290,23 +291,18 @@ inline void apply_binary_elementwise_fn(
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-    }
-
-    data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
-  }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
+    }
+  }
 }

@@ -338,28 +334,16 @@ inline void apply_ternary_elementwise_fn(
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  for (const auto i : c10::irange(out.numel())) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
-    }
-
-    data_out[i] = compute_fun(
-        data_a[a_linear_index], data_b[b_linear_index], data_c[c_linear_index]);
-  }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      data_out[out_index] =
+          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
+    }
+  } else {
+    for (const auto i : c10::irange(out.numel())) {
+      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
+    }
+  }
 }
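Both hunks in this file make the same transformation: the old loop re-checked the broadcast flags and delinearized the output index on every element, while the new code hoists the loop-invariant any_is_broadcasted test out of the loop, leaving one tight loop for the contiguous case and one BroadcastIndexesRange loop for the broadcast case. The pattern in miniature, with hypothetical helper names (not from this codebase):

#include <cstddef>
#include <vector>

// Before: a loop-invariant condition branched on per element.
void add_maybe_scalar_slow(
    const std::vector<float>& a,
    const std::vector<float>& b,
    std::vector<float>& out,
    bool b_is_scalar) {
  for (size_t i = 0; i < out.size(); ++i) {
    const size_t bi = b_is_scalar ? 0 : i; // evaluated out.size() times
    out[i] = a[i] + b[bi];
  }
}

// After: test the condition once; each loop body stays tight.
void add_maybe_scalar_fast(
    const std::vector<float>& a,
    const std::vector<float>& b,
    std::vector<float>& out,
    bool b_is_scalar) {
  if (b_is_scalar) {
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = a[i] + b[0];
    }
  } else {
    for (size_t i = 0; i < out.size(); ++i) {
      out[i] = a[i] + b[i];
    }
  }
}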

81 changes: 38 additions & 43 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <c10/util/irange.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_runtime_context.h>
@@ -121,26 +122,24 @@ inline void apply_bitensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-    }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
-  }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index] :
+         BroadcastIndexesRange<2>(out, a, b)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
+  }
 }

@@ -211,31 +210,27 @@ inline void apply_tritensor_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
 
   auto out_numel = out.numel();
-  for (const auto i : c10::irange(out_numel)) {
-    size_t a_linear_index = i;
-    size_t b_linear_index = i;
-    size_t c_linear_index = i;
-
-    if (any_is_broadcasted) {
-      size_t out_indexes[kTensorDimensionLimit];
-      delinearize_index(i, out, out_indexes, kTensorDimensionLimit);
-
-      if (a_is_broadcasted) {
-        a_linear_index = linearize_access_indexes(out_indexes, out.dim(), a);
-      }
-      if (b_is_broadcasted) {
-        b_linear_index = linearize_access_indexes(out_indexes, out.dim(), b);
-      }
-      if (c_is_broadcasted) {
-        c_linear_index = linearize_access_indexes(out_indexes, out.dim(), c);
-      }
-    }
-
-    auto result = compute_fun(
-        load_a_to_common(&data_a[a_linear_index * a_element_size]),
-        load_b_to_common(&data_b[b_linear_index * b_element_size]),
-        load_c_to_common(&data_c[c_linear_index * c_element_size]));
-    store_common_to_out(result, &data_out[i * out_element_size]);
-  }
+  if (any_is_broadcasted) {
+    for (const auto [out_index, a_index, b_index, c_index] :
+         BroadcastIndexesRange<3>(out, a, b, c)) {
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_index * a_element_size]),
+          load_b_to_common(&data_b[b_index * b_element_size]),
+          load_c_to_common(&data_c[c_index * c_element_size]));
+      store_common_to_out(result, &data_out[out_index * out_element_size]);
+    }
+  } else {
+    for (const auto i : c10::irange(out_numel)) {
+      size_t a_linear_index = i;
+      size_t b_linear_index = i;
+      size_t c_linear_index = i;
+
+      auto result = compute_fun(
+          load_a_to_common(&data_a[a_linear_index * a_element_size]),
+          load_b_to_common(&data_b[b_linear_index * b_element_size]),
+          load_c_to_common(&data_c[c_linear_index * c_element_size]));
+      store_common_to_out(result, &data_out[i * out_element_size]);
+    }
+  }
 }

3 changes: 3 additions & 0 deletions kernels/portable/cpu/util/targets.bzl
@@ -70,6 +70,9 @@ def define_common_targets():
         exported_headers = [
            "broadcast_util.h",
         ],
+        exported_deps = [
+            ":broadcast_indexes_range",
+        ],
         deps = [
             ":repeat_util",
             "//executorch/runtime/kernel:kernel_includes",