Skip to content

Commit c183ef0

Browse files
authored
parallelize optimized op_where using parallel_for (#9059)
An internal model saw a 5.7% latency improvement (313.8 ms before, 296.0 ms after).
1 parent f91ebe2 commit c183ef0

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

kernels/optimized/cpu/op_where.cpp

+25-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
*/
88
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
99
#include <executorch/runtime/kernel/kernel_includes.h>
10-
#include <iostream>
10+
#include <executorch/runtime/kernel/thread_parallel_interface.h>
1111

1212
namespace torch {
1313
namespace executor {
@@ -58,15 +58,31 @@ Tensor& opt_where_out(
5858
const bool* const data_cond = cond.const_data_ptr<bool>();
5959
CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
6060
if (any_is_broadcasted) {
61-
for (const auto [out_index, a_index, b_index, cond_index] :
62-
BroadcastIndexesRange<3>(out, a, b, cond)) {
63-
data_out[out_index] =
64-
data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
65-
}
61+
executorch::extension::parallel_for(
62+
0,
63+
out_numel,
64+
::executorch::extension::internal::GRAIN_SIZE,
65+
[&](const auto begin, const auto end) {
66+
auto range = BroadcastIndexesRange<3>(out, a, b, cond);
67+
auto begin_it = range.begin();
68+
begin_it += begin;
69+
for (; (*begin_it)[0] < end; ++begin_it) {
70+
const auto [out_index, a_index, b_index, cond_index] =
71+
*begin_it;
72+
data_out[out_index] =
73+
data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
74+
}
75+
});
6676
} else {
67-
for (const auto i : c10::irange(out_numel)) {
68-
data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
69-
}
77+
executorch::extension::parallel_for(
78+
0,
79+
out_numel,
80+
::executorch::extension::internal::GRAIN_SIZE,
81+
[&](const auto begin, const auto end) {
82+
for (const auto i : c10::irange(begin, end)) {
83+
data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
84+
}
85+
});
7086
}
7187
});
7288
} else {

runtime/kernel/thread_parallel_interface.h

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ inline bool parallel_for_no_threadpool(
4444
return true;
4545
}
4646

47+
// Grain size for kernels to pass to parallel_for; matches GRAIN_SIZE from PyTorch core.
48+
// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/TensorIterator.h#L78
49+
constexpr int64_t GRAIN_SIZE = 32768;
4750
} // namespace internal
4851

4952
#ifdef ET_USE_THREADPOOL

0 commit comments

Comments
 (0)