Commit 243c885
Proof-of-concept: parallelize argmin (#9066)
I attempted to port `at::parallel_reduce` to ExecuTorch and use it in reduce_util.h, but that turned out to be much trickier than expected. (In brief: a parallel reduction requires two steps: 1) split the input range into chunks and reduce over each chunk (easily done, just like parallel_for), and then 2) combine the sub-results from the chunks. The reduction function accepted by reduce_over_dim is not well-suited to step (2).) Instead, I ported the parallelization strategy used by binary_kernel_reduce_lastdim: just parallelize over the *non*-reduced dimensions of the tensor. I don't see any reason this strategy isn't generally applicable, and we aren't otherwise capable of parallelizing reductions, so I haven't gated it to the case where we are reducing over a contiguous last dimension. I will send a follow-up that packages this strategy up nicely and uses it in our portable reduction ops.
1 parent 907e97e commit 243c885
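
To make the two-step structure in the message concrete, here is a minimal sketch of a parallel_reduce-style reduction in plain C++. It uses std::thread and illustrative names rather than any real ExecuTorch or ATen API; the point is only to show why step (2) needs a function that merges two accumulators, which reduce_over_dim's element-at-a-time callback cannot express.

#include <algorithm>
#include <cstdint>
#include <thread>
#include <vector>

template <typename T, typename Reduce, typename Combine>
T parallel_reduce_sketch(
    const T* data,
    int64_t n,
    int64_t num_chunks,
    T identity,
    Reduce reduce,
    Combine combine) {
  // Step 1: reduce each chunk independently. This part has the same
  // shape as a parallel_for and is easy to express.
  std::vector<T> partials(num_chunks, identity);
  std::vector<std::thread> workers;
  const int64_t chunk = (n + num_chunks - 1) / num_chunks;
  for (int64_t c = 0; c < num_chunks; ++c) {
    workers.emplace_back([&, c] {
      const int64_t begin = c * chunk;
      const int64_t end = std::min(n, begin + chunk);
      for (int64_t i = begin; i < end; ++i) {
        partials[c] = reduce(partials[c], data[i]);
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }
  // Step 2: merge the per-chunk sub-results. This needs a function that
  // combines two *accumulators*; reduce_over_dim's callback only folds a
  // single element (plus its index) into an accumulator, so it cannot
  // express this merge directly.
  T result = identity;
  for (const T& p : partials) {
    result = combine(result, p);
  }
  return result;
}

For argmin specifically, T would be a (value, index) pair, reduce would apply the NaN-aware comparison seen in the diff below, and combine would pick whichever pair holds the smaller value.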

File tree

2 files changed: +39 −24 lines

kernels/portable/cpu/op_argmin.cpp

Lines changed: 38 additions & 24 deletions
@@ -12,6 +12,7 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
@@ -47,30 +48,43 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    for (const auto out_ix : c10::irange(out.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            // The condition below as written is equivalent to
-            // !isnan(acc_val) && (isnan(v) || v < acc_val). Cases:
-            // - If neither acc_val nor v is NaN, !(v >= acc_val) is
-            //   trivially equivalent to v < acc_val.
-            // - If acc_val is NaN, the whole thing is trivially false.
-            // - If acc_val is not NaN and v is NaN, then v >= acc_val
-            //   is false because all comparisons involving NaN are
-            //   false, so the result is true. The result is trivially
-            //   true for the equivalent condition that uses isnan(v)
-            //   as well.
-            if (!std::isnan(acc_val) && !(v >= acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      out_data[out_ix] = std::get<1>(acc);
-    }
+    // REVIEW: this is the parallelization strategy ATen uses
+    // specifically when the reduction is along the last dimension and
+    // that dimension is contiguous. Is there any particular reason we
+    // shouldn't just always use this strategy, since we aren't
+    // otherwise capable of parallelizing reductions?
+    const int64_t reduction_size = get_reduced_dim_product(in, dim);
+    const auto grain_size = std::max(
+        static_cast<int64_t>(1),
+        executorch::extension::internal::GRAIN_SIZE / reduction_size);
+    const bool success = executorch::extension::parallel_for(
+        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  // The condition below as written is equivalent to
+                  // !isnan(acc_val) && (isnan(v) || v < acc_val). Cases:
+                  // - If neither acc_val nor v is NaN, !(v >= acc_val) is
+                  //   trivially equivalent to v < acc_val.
+                  // - If acc_val is NaN, the whole thing is trivially false.
+                  // - If acc_val is not NaN and v is NaN, then v >= acc_val
+                  //   is false because all comparisons involving NaN are
+                  //   false, so the result is true. The result is trivially
+                  //   true for the equivalent condition that uses isnan(v)
+                  //   as well.
+                  if (!std::isnan(acc_val) && !(v >= acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            out_data[out_ix] = std::get<1>(acc);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return out;
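
To make the grain-size arithmetic above concrete, here is a standalone sketch in plain C++. The value 32768 is an assumption for illustration (it matches ATen's at::internal::GRAIN_SIZE; ExecuTorch's executorch::extension::internal::GRAIN_SIZE may differ), and the sequential loop only mimics how a parallel_for implementation might carve up the range.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Assumed stand-in for executorch::extension::internal::GRAIN_SIZE;
  // ATen uses 32768, ExecuTorch's value may differ.
  const int64_t kGrainSize = 32768;
  const int64_t reduction_size = 1024; // inputs reduced per output element
  const int64_t out_numel = 256;       // number of output elements

  // Each output element costs ~reduction_size input reads, so shrink the
  // per-task output count until one task touches ~kGrainSize inputs.
  const int64_t grain_size =
      std::max(static_cast<int64_t>(1), kGrainSize / reduction_size);

  // Mimic how a parallel_for-style runner would split [0, out_numel);
  // each [begin, end) task would run the reduce_over_dim loop above.
  for (int64_t begin = 0; begin < out_numel; begin += grain_size) {
    const int64_t end = std::min(out_numel, begin + grain_size);
    std::cout << "task over outputs [" << begin << ", " << end << "): ~"
              << (end - begin) * reduction_size << " input elements\n";
  }
  return 0;
}

With these numbers grain_size comes out to 32, so the 256 outputs split into 8 tasks of 32 output elements, each touching roughly 32768 input elements: the per-output reduction cost is folded into the grain size so tasks stay coarse enough to amortize scheduling overhead.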

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ ATEN_OPS = (
         name = "op_argmin",
         deps = [
             "//executorch/kernels/portable/cpu/util:reduce_util",
+            "//executorch/runtime/kernel:thread_parallel_interface",
         ],
     ),
     op_target(
