Make BroadcastIndexesRange efficient if there is no broadcasting #9298

Merged: 7 commits, Mar 18, 2025.
Changes from all commits:
13 changes: 11 additions & 2 deletions examples/portable/executor_runner/executor_runner.cpp
@@ -30,6 +30,7 @@
#include <executorch/runtime/executor/method.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/log.h>
#include <executorch/runtime/platform/platform.h>
#include <executorch/runtime/platform/runtime.h>
#ifdef ET_EVENT_TRACER_ENABLED
#include <executorch/devtools/etdump/etdump_flatcc.h>
@@ -249,6 +250,7 @@ int main(int argc, char** argv) {
(uint32_t)method.error());
ET_LOG(Info, "Method loaded.");

et_timestamp_t time_spent_executing = 0;
// Run the model.
for (uint32_t i = 0; i < FLAGS_num_executions; i++) {
ET_LOG(Debug, "Preparing inputs.");
@@ -267,17 +269,24 @@
(uint32_t)inputs.error());
ET_LOG(Debug, "Inputs prepared.");

const et_timestamp_t before_execute = et_pal_current_ticks();
Error status = method->execute();
const et_timestamp_t after_execute = et_pal_current_ticks();
time_spent_executing += after_execute - before_execute;
ET_CHECK_MSG(
status == Error::Ok,
"Execution of method %s failed with status 0x%" PRIx32,
method_name,
(uint32_t)status);
}
const auto tick_ratio = et_pal_ticks_to_ns_multiplier();
constexpr auto NANOSECONDS_PER_MILLISECOND = 1000000;
ET_LOG(
Info,
"Model executed successfully %" PRIu32 " time(s).",
FLAGS_num_executions);
"Model executed successfully %" PRIu32 " time(s) in %f ms.",
FLAGS_num_executions,
static_cast<double>(time_spent_executing) * tick_ratio.numerator /
tick_ratio.denominator / NANOSECONDS_PER_MILLISECOND);

// Print the outputs.
std::vector<EValue> outputs(method->outputs_size());
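For reference, the tick-to-millisecond conversion added above can be factored into a small helper. A minimal sketch, assuming only the et_pal APIs already used in this diff (the helper name is illustrative):

#include <executorch/runtime/platform/platform.h>

// Hypothetical helper mirroring the logging arithmetic above:
// ticks * numerator / denominator yields nanoseconds; dividing by 1e6
// converts to milliseconds.
double ticks_to_ms(et_timestamp_t ticks) {
  const auto ratio = et_pal_ticks_to_ns_multiplier();
  constexpr double kNanosecondsPerMillisecond = 1000000.0;
  return static_cast<double>(ticks) * ratio.numerator / ratio.denominator /
      kNanosecondsPerMillisecond;
}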
46 changes: 14 additions & 32 deletions kernels/optimized/cpu/op_where.cpp
@@ -48,42 +48,24 @@ Tensor& opt_where_out(
cond.scalar_type() == ScalarType::Bool) {
auto out_numel = out.numel();
ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
const bool any_is_broadcasted =
(a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
const bool* const data_cond = cond.const_data_ptr<bool>();
CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
if (any_is_broadcasted) {
executorch::extension::parallel_for(
0,
out_numel,
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
auto range = BroadcastIndexesRange<3>(out, a, b, cond);
auto begin_it = range.begin();
begin_it += begin;
for (; (*begin_it)[0] < end; ++begin_it) {
const auto [out_index, a_index, b_index, cond_index] =
*begin_it;
data_out[out_index] =
data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
}
});
} else {
executorch::extension::parallel_for(
0,
out_numel,
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
for (const auto i : c10::irange(begin, end)) {
data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
}
});
}
executorch::extension::parallel_for(
0,
out_numel,
::executorch::extension::internal::GRAIN_SIZE,
[&](const auto begin, const auto end) {
auto range = BroadcastIndexesRange<3>(out, a, b, cond);
auto begin_it = range.begin();
begin_it += begin;
for (; (*begin_it)[0] < end; ++begin_it) {
const auto [out_index, a_index, b_index, cond_index] = *begin_it;
data_out[out_index] =
data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
}
});
});
} else {
// Fall back for mixed dtype to keep code size and compile time
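The single code path above works because of the iterator change in this PR: when no input is broadcast, BroadcastIndexesRange produces identical output and input indexes, so the unified loop degenerates to the contiguous loop it replaces. A sketch of that case, assuming a, b, cond, and out all share one shape:

// Sketch: with equal shapes, every yielded tuple is (i, i, i, i), making
// this equivalent to data_out[i] = data_cond[i] ? data_a[i] : data_b[i].
for (const auto [out_index, a_index, b_index, cond_index] :
     BroadcastIndexesRange<3>(out, a, b, cond)) {
  data_out[out_index] =
      data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
}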
26 changes: 17 additions & 9 deletions kernels/portable/cpu/util/broadcast_indexes_range.h
@@ -34,14 +34,17 @@ class BroadcastIndexesIterator {

template <typename... Args>
explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
: output_dim_(output.dim()),
output_shape_(output.sizes()),
effective_input_broadcast_strides_{
effective_input_broadcast_stride(output, args)...} {
: output_dim_or_zero_if_no_broadcasting_(
((args.sizes() == output.sizes()) && ...) ? 0 : output.dim()),
output_shape_(output.sizes()) {
static_assert(
sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
"BroadcastIndexesIterator constructor requires kNumInputs input tensor"
"arguments!");
if (output_dim_or_zero_if_no_broadcasting_ != 0) {
effective_input_broadcast_strides_ = {
effective_input_broadcast_stride(output, args)...};
}
}

struct make_end_t {
@@ -73,9 +76,14 @@

BroadcastIndexesIterator& operator++() {
output_index()++;
if (output_dim_or_zero_if_no_broadcasting_ == 0) {
std::fill(
current_indexes_.begin() + 1, current_indexes_.end(), output_index());
return *this;
}
// TODO: add optimization for particular input tensors not being
// broadcasted?
for (auto ii = output_dim_ - 1; ii >= 0; --ii) {
for (auto ii = output_dim_or_zero_if_no_broadcasting_ - 1; ii >= 0; --ii) {
// You might wonder what happens if output_shape_[ii] == 0. In
// that case, output.numel() would be 0, and thus we would have
// begin() == end() and no iteration.
@@ -121,7 +129,8 @@
delinearized_output_index_.size());
for (const auto ii : c10::irange(1, kNumInputs + 1)) {
current_indexes_[ii] = 0;
for (const auto jj : c10::irange(output_dim_)) {
for (const auto jj :
c10::irange(output_dim_or_zero_if_no_broadcasting_)) {
current_indexes_[ii] += delinearized_output_index_[jj] *
effective_input_broadcast_strides_[ii - 1][jj];
}
@@ -180,7 +189,7 @@
// followed by kNumInputs input indexes.
std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
ShapeType delinearized_output_index_ = {0};
ssize_t output_dim_;
ssize_t output_dim_or_zero_if_no_broadcasting_;
ArrayRef<exec_aten::SizesType> output_shape_;
// The linear index for a broadcast tensor is
// sum(delinearized_output_index_[i] * input_stride_[i] if
@@ -189,8 +198,7 @@
// output_dim. This is straightforwardly implementable with an
// adjusted stride array that contains 0s where the padded input
// shape would contain 1s.
std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_ = {
{{0}}};
std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_;
};
} // namespace internal

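To see why the no-broadcast fast path is cheap: operator++ above bumps the output index and copies it into every input slot with std::fill, skipping the per-dimension carry loop entirely. A hypothetical standalone demo (not ExecuTorch code) of that update:

#include <algorithm>
#include <array>
#include <cstdio>

int main() {
  // Mirrors the iterator's current_indexes_ layout: output index first,
  // then one index per input (three inputs here).
  std::array<long, 4> indexes = {0, 0, 0, 0};
  for (int step = 0; step < 3; ++step) {
    ++indexes[0]; // advance the output index
    // Fast path: every input index equals the output index.
    std::fill(indexes.begin() + 1, indexes.end(), indexes[0]);
    std::printf(
        "(%ld, %ld, %ld, %ld)\n", indexes[0], indexes[1], indexes[2],
        indexes[3]);
  }
  return 0; // prints (1, 1, 1, 1), (2, 2, 2, 2), (3, 3, 3, 3)
}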
39 changes: 7 additions & 32 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -254,26 +254,13 @@ inline void apply_binary_elementwise_fn(
const Tensor& a,
const Tensor& b,
const Tensor& out) {
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);

const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index] :
BroadcastIndexesRange<2>(out, a, b)) {
data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
size_t a_linear_index = i;
size_t b_linear_index = i;

data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
}
for (const auto [out_index, a_index, b_index] :
BroadcastIndexesRange<2>(out, a, b)) {
data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
}
}

@@ -294,27 +281,15 @@ inline void apply_ternary_elementwise_fn(
const Tensor& b,
const Tensor& c,
const Tensor& out) {
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
const bool any_is_broadcasted =
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);

const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();

if (any_is_broadcasted) {
for (const auto [out_index, a_index, b_index, c_index] :
BroadcastIndexesRange<3>(out, a, b, c)) {
data_out[out_index] =
compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
}
for (const auto [out_index, a_index, b_index, c_index] :
BroadcastIndexesRange<3>(out, a, b, c)) {
data_out[out_index] =
compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
}
}

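A hypothetical call site for the simplified binary helper, with tensor setup elided; the template-parameter roles (CTYPE_A, CTYPE_B, CTYPE_OUT) and the compute_fun-first parameter order are assumptions read off the body above. After this change the same BroadcastIndexesRange loop serves broadcast and non-broadcast inputs alike:

// Sketch: elementwise add of float tensors a and b into out, whether or
// not b is broadcast against a's shape.
apply_binary_elementwise_fn<float, float, float>(
    [](float x, float y) { return x + y; }, a, b, out);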
36 changes: 9 additions & 27 deletions kernels/portable/cpu/util/elementwise_util.h
@@ -76,11 +76,6 @@ inline void apply_elementwise_fn(
internal::check_tensor_dtype(out, out_dtypes, compute_type),
InvalidArgument, );

bool any_is_broadcasted = false;
if constexpr (kNumInputs > 1) {
any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
}

struct InputInfo {
load_to_common_fn<CTYPE_COMMON> load_to_common;
const char* data_ptr;
@@ -99,29 +94,16 @@
char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
const auto out_element_size = out.element_size();

if (any_is_broadcasted) {
for (const auto& indexes :
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto idx : c10::irange(kNumInputs)) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_common(
&input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
}
} else {
for (const auto i : c10::irange(out.numel())) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto idx : c10::irange(kNumInputs)) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_common(
&input_info.data_ptr[i * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[i * out_element_size]);
for (const auto& indexes :
BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
for (const auto idx : c10::irange(kNumInputs)) {
const auto& input_info = inputs_info[idx];
loaded_inputs[idx] = input_info.load_to_common(
&input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
}
auto result = std::apply(compute_fun, loaded_inputs);
store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
}
}
} // namespace internal
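The char* arithmetic above works because the data pointers are type-erased and scaled by per-tensor element sizes. A hypothetical standalone sketch of that load-compute-store pattern (names are illustrative, not ExecuTorch APIs):

#include <cstring>

// Load one element of the common compute type from a type-erased pointer.
template <typename Common>
Common load_common(const char* p) {
  Common v;
  std::memcpy(&v, p, sizeof(Common));
  return v;
}

// Elementwise add over raw byte buffers, mirroring the indexing above for
// the no-broadcast case, where output and input indexes coincide.
void add_f32(const char* a, const char* b, char* out, long numel) {
  constexpr long kElemSize = sizeof(float);
  for (long i = 0; i < numel; ++i) {
    const float result =
        load_common<float>(a + i * kElemSize) +
        load_common<float>(b + i * kElemSize);
    std::memcpy(out + i * kElemSize, &result, kElemSize);
  }
}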