
Commit 2d761f0

Add parallel_for_each_reduce_over_dim_list_output_index (#9143)
1 parent 91584e2 commit 2d761f0

File tree

1 file changed: +23 −3 lines changed


kernels/portable/cpu/util/reduce_util.h

+23 −3

@@ -823,10 +823,30 @@ template <typename Func>
     executorch::aten::optional<int64_t> dim,
     const Tensor& out,
     const Func& func) {
-  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim);
   const auto grain_size = std::max(
-      static_cast<int64_t>(1),
-      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim_list or
+ * map_reduce_over_dim_list for each output element. Automatically
+ * calculates appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_list_output_index(
+    const Tensor& in,
+    optional<ArrayRef<int64_t>> dim_list,
+    const Tensor& out,
+    const Func& func) {
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim_list);
+  const auto grain_size = std::max(
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
   return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
 }
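For context, a minimal sketch of how a portable reduction kernel might call the newly added wrapper. Only parallel_for_each_reduce_over_dim_list_output_index and the (begin, end) range convention it forwards to executorch::extension::parallel_for are taken from the diff above; the kernel name sum_over_dim_list_sketch and the reduce_one_output helper are hypothetical placeholders, not code from this commit.

// Sketch only: assumes this sits in a portable kernel source file where
// Tensor, optional, and ArrayRef are already in scope, as they are for the
// ops under kernels/portable/cpu/.
#include <executorch/kernels/portable/cpu/util/reduce_util.h>

// Hypothetical per-output-element reduction (not part of this commit); a real
// kernel would typically use reduce_over_dim_list or map_reduce_over_dim_list
// here, as the new doc comment suggests.
template <typename CTYPE>
CTYPE reduce_one_output(
    const Tensor& in, optional<ArrayRef<int64_t>> dim_list, int64_t out_ix);

template <typename CTYPE>
bool sum_over_dim_list_sketch(
    const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
    Tensor& out) {
  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
  // The wrapper derives a grain size from GRAIN_SIZE / reduction_size and
  // hands each worker a [begin, end) range of output indices.
  return parallel_for_each_reduce_over_dim_list_output_index(
      in, dim_list, out, [&](const int64_t begin, const int64_t end) {
        for (int64_t out_ix = begin; out_ix < end; ++out_ix) {
          out_data[out_ix] = reduce_one_output<CTYPE>(in, dim_list, out_ix);
        }
      });
}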
