@@ -7,7 +7,7 @@
  */
 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <iostream>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 
 namespace torch {
 namespace executor {
@@ -58,15 +58,31 @@ Tensor& opt_where_out(
       const bool* const data_cond = cond.const_data_ptr<bool>();
       CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
       if (any_is_broadcasted) {
-        for (const auto [out_index, a_index, b_index, cond_index] :
-             BroadcastIndexesRange<3>(out, a, b, cond)) {
-          data_out[out_index] =
-              data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
-        }
+        executorch::extension::parallel_for(
+            0,
+            out_numel,
+            ::executorch::extension::internal::GRAIN_SIZE,
+            [&](const auto begin, const auto end) {
+              auto range = BroadcastIndexesRange<3>(out, a, b, cond);
+              auto begin_it = range.begin();
+              begin_it += begin;
+              for (; (*begin_it)[0] < end; ++begin_it) {
+                const auto [out_index, a_index, b_index, cond_index] =
+                    *begin_it;
+                data_out[out_index] =
+                    data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+              }
+            });
       } else {
-        for (const auto i : c10::irange(out_numel)) {
-          data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
-        }
+        executorch::extension::parallel_for(
+            0,
+            out_numel,
+            ::executorch::extension::internal::GRAIN_SIZE,
+            [&](const auto begin, const auto end) {
+              for (const auto i : c10::irange(begin, end)) {
+                data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
+              }
+            });
       }
     });
   } else {