@@ -12,6 +12,7 @@
 
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/kernel/thread_parallel_interface.h>
 #include <executorch/runtime/platform/assert.h>
 
 namespace torch {
@@ -47,30 +48,43 @@ Tensor& argmin_out(
   ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "argmin.out", CTYPE, [&] {
     long* out_data = out.mutable_data_ptr<long>();
 
-    for (const auto out_ix : c10::irange(out.numel())) {
-      std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
-          [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
-            // the below condition as written is equivalent to !isnan(accval) &&
-            // (isnan(v) || v < acc_val). cases:
-            // - if neither acc_val nor v is NaN, !(v >= acc_val) is
-            // trivially equivalent to v < acc_val.
-            // - if acc_val is NaN, the whole thing is trivially false.
-            // - if acc_val is not NaN and v is NaN, then v >= acc_val
-            // - is false because all comparisons involving NaN are
-            // - false, so the result is true. The result is trivially
-            // - true for the above condition that uses isnan(v) as
-            // - well.
-            if (!std::isnan(acc_val) && !(v >= acc_val)) {
-              acc_val = v;
-              acc_ix = ix;
-            }
-            return std::tuple<CTYPE, long>{acc_val, acc_ix};
-          },
-          in,
-          dim,
-          out_ix);
-      out_data[out_ix] = std::get<1>(acc);
-    }
+    // REVIEW: this is the parallelization strategy ATen uses
+    // specifically when the reduction is along the last dimension and
+    // that dimension is contiguous. Is there any particular reason we
+    // shouldn't just always use this strategy since we aren't
+    // otherwise capable of parallelizing reductions?
+    const int64_t reduction_size = get_reduced_dim_product(in, dim);
+    const auto grain_size = std::max(
+        static_cast<int64_t>(1),
+        executorch::extension::internal::GRAIN_SIZE / reduction_size);
+    const bool success = executorch::extension::parallel_for(
+        0, out.numel(), grain_size, [&](const auto begin, const auto end) {
+          for (const auto out_ix : c10::irange(begin, end)) {
+            std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
+                [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
+                  // the below condition as written is equivalent to
+                  // !isnan(acc_val) && (isnan(v) || v < acc_val). cases:
+                  // - if neither acc_val nor v is NaN, !(v >= acc_val) is
+                  //   trivially equivalent to v < acc_val.
+                  // - if acc_val is NaN, the whole thing is trivially false.
+                  // - if acc_val is not NaN and v is NaN, then v >= acc_val
+                  //   is false because all comparisons involving NaN are
+                  //   false, so the result is true. The result is trivially
+                  //   true for the above condition that uses isnan(v) as
+                  //   well.
+                  if (!std::isnan(acc_val) && !(v >= acc_val)) {
+                    acc_val = v;
+                    acc_ix = ix;
+                  }
+                  return std::tuple<CTYPE, long>{acc_val, acc_ix};
+                },
+                in,
+                dim,
+                out_ix);
+            out_data[out_ix] = std::get<1>(acc);
+          }
+        });
+    ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
   });
 
   return out;
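The NaN handling inside the reducer is the subtle part of this kernel (unchanged by the diff apart from comment wrapping). A throwaway check like the one below, written only for this review and not part of the change, confirms the comment's claim that the condition as written agrees with the expanded form `!isnan(acc_val) && (isnan(v) || v < acc_val)` for every combination of NaN and non-NaN operands:

```cpp
// Standalone verification of the reducer's comparison, not part of the kernel.
#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float vals[] = {-1.0f, 0.0f, 2.5f, nan};
  for (float acc_val : vals) {
    for (float v : vals) {
      // Condition as written in the kernel.
      const bool as_written = !std::isnan(acc_val) && !(v >= acc_val);
      // Expanded form described in the kernel's comment.
      const bool spelled_out =
          !std::isnan(acc_val) && (std::isnan(v) || v < acc_val);
      assert(as_written == spelled_out);
    }
  }
  return 0;
}
```

In particular, when `v` is NaN and `acc_val` is not, `!(v >= acc_val)` is true, so the NaN takes over as the running minimum and its index is what argmin reports, exactly as the case analysis in the comment describes.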