@@ -823,10 +823,30 @@ template <typename Func>
     executorch::aten::optional<int64_t> dim,
     const Tensor& out,
     const Func& func) {
-  const int64_t reduction_size = get_reduced_dim_product(in, dim);
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim);
   const auto grain_size = std::max(
-      static_cast<int64_t>(1),
-      executorch::extension::internal::GRAIN_SIZE / reduction_size);
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
+  return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
+}
+
+/**
+ * parallel_for wrapper for reductions that call reduce_over_dim_list or
+ * map_reduce_over_dim_list for each output element. Automatically
+ * calculates appropriate grain size.
+ */
+template <typename Func>
+[[nodiscard]] bool parallel_for_each_reduce_over_dim_list_output_index(
+    const Tensor& in,
+    optional<ArrayRef<int64_t>> dim_list,
+    const Tensor& out,
+    const Func& func) {
+  const ssize_t reduction_size = get_reduced_dim_product(in, dim_list);
+  const auto grain_size = std::max(
+      static_cast<ssize_t>(1),
+      static_cast<ssize_t>(executorch::extension::internal::GRAIN_SIZE) /
+          reduction_size);
   return executorch::extension::parallel_for(0, out.numel(), grain_size, func);
 }
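For context, a minimal usage sketch of the new dim-list wrapper (illustration only, not part of the diff). It assumes the functor passed to the wrapper receives a [begin, end) range of output indices, as executorch::extension::parallel_for invokes its callback, and that map_reduce_over_dim_list takes a map functor, a reduce functor, the input tensor, the dim list, and an output index; out_data, CTYPE_IN, and CTYPE_OUT are hypothetical names for the kernel's output pointer and element types.

// Sketch only: sums over dim_list for each output element. The returned
// bool must be checked by the caller, since the wrapper is [[nodiscard]].
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
    in, dim_list, out, [&](const auto begin, const auto end) {
      for (auto out_ix = begin; out_ix < end; ++out_ix) {
        out_data[out_ix] = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
            [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
            [](CTYPE_OUT v, CTYPE_OUT acc) { return acc + v; },
            in,
            dim_list,
            out_ix);
      }
    });

The grain size divides GRAIN_SIZE by the per-output reduction size so that each parallel task does roughly GRAIN_SIZE elements of total work: the larger the reduction over dim_list, the fewer output elements each task handles.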