@@ -721,6 +721,97 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body,
721
721
return loop_strides;
722
722
}
723
723
724
+ // Check whether we can apply grid reduce in this fusion group.
725
+ // We can apply grid reduce if there is exactly one reduce, and whose result is
726
+ // not broadcasted before output.
727
+ bool GetCanApplyGridReduce (const std::vector<ir::Expr>& op_compute_bodies,
728
+ const std::vector<int64_t >& reduce_axis) {
729
+ using trivial_fusion_detail::GetAllForIters;
730
+ using trivial_fusion_detail::IsReduceBody;
731
+ using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes;
732
+ using trivial_fusion_detail::ExprSetFinderUtils::ChildStores;
733
+ using trivial_fusion_detail::ExprSetFinderUtils::ChildTensorLoads;
734
+ using trivial_fusion_detail::ExprSetFinderUtils::
735
+ ScheduleBlockRealizeIsNotInit;
736
+
737
+ // Names of tensors that are downstream of reduce.
738
+ // A tensor is downstream of reduce either if it is produced by a reduce, or
739
+ // if it has data dependency on another tensor that is downstream of reduce.
740
+ std::unordered_set<std::string> reduce_downstream_tensor_names;
741
+ int reduce_count = 0 ;
742
+
743
+ const auto IsReduceDownstream = [&](const ir::Expr& expr_block) {
744
+ for (auto & expr_load : ChildTensorLoads (expr_block)) {
745
+ std::string load_tensor_name = expr_load.As <ir::Load>()->name ();
746
+ if (reduce_downstream_tensor_names.count (load_tensor_name) > 0 ) {
747
+ return true ;
748
+ }
749
+ }
750
+ return false ;
751
+ };
752
+
753
+ const auto AddReduceDownstream = [&](const ir::Expr& expr_block) {
754
+ auto expr_store = ChildStores.GetSingle (expr_block);
755
+ std::string store_tensor_name = expr_store.As <ir::Store>()->name ();
756
+ reduce_downstream_tensor_names.insert (store_tensor_name);
757
+ };
758
+
759
+ const auto CheckOutputHasReduceAxis = [&](const ir::Expr& body,
760
+ const ir::Expr& expr_block) {
761
+ std::vector<ir::Var> all_loop_vars = GetAllForIters (body);
762
+ std::unordered_set<std::string> reduce_loop_vars;
763
+ for (int64_t axis : reduce_axis) {
764
+ reduce_loop_vars.insert (all_loop_vars[axis]->name );
765
+ }
766
+
767
+ std::unordered_set<std::string> reduce_iter_vars;
768
+ auto * block = expr_block.As <ir::ScheduleBlockRealize>();
769
+ auto & iter_vars = block->schedule_block .As <ir::ScheduleBlock>()->iter_vars ;
770
+ for (int i = 0 ; i < iter_vars.size (); i++) {
771
+ ir::Var loop_var = block->iter_values [i].as_var_ref ();
772
+ if (reduce_loop_vars.count (loop_var->name ) > 0 ) {
773
+ reduce_iter_vars.insert (iter_vars[i]->name );
774
+ }
775
+ }
776
+
777
+ // The result is true if the indices of the output tensor contain any
778
+ // reduce iter vars.
779
+ auto expr_store = ChildStores.GetSingle (expr_block);
780
+ for (auto & index_expr : expr_store.As <ir::Store>()->indices ) {
781
+ if (reduce_iter_vars.count (index_expr.as_var_ref ()->name ) > 0 ) {
782
+ return true ;
783
+ }
784
+ }
785
+ return false ;
786
+ };
787
+
788
+ for (const auto & body : op_compute_bodies) {
789
+ ir::Expr expr_block =
790
+ (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit)
791
+ .GetSingle (body);
792
+ bool is_reduce_body = IsReduceBody (body);
793
+ bool is_reduce_downstream = IsReduceDownstream (expr_block);
794
+ bool output_has_reduce_axis = CheckOutputHasReduceAxis (body, expr_block);
795
+
796
+ if (is_reduce_body) {
797
+ ++reduce_count;
798
+ }
799
+ if (is_reduce_downstream || is_reduce_body) {
800
+ AddReduceDownstream (expr_block);
801
+ }
802
+
803
+ // When a block is downstream of reduce, its output shouldn't contain
804
+ // reduce axis. Otherwise, it broadcasts the result of reduce. If this
805
+ // is the case, we cannot apply grid reduce.
806
+ if (is_reduce_downstream && output_has_reduce_axis) {
807
+ VLOG (4 ) << " grid reduce is prohibited by block: " << expr_block;
808
+ return false ;
809
+ }
810
+ }
811
+
812
+ return reduce_count == 1 ;
813
+ }
814
+
724
815
std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo (
725
816
const std::vector<ir::Expr>& op_compute_bodies) {
726
817
using trivial_fusion_detail::AppendBound;
@@ -792,6 +883,10 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
792
883
}
793
884
});
794
885
}
886
+
887
+ group_info->can_apply_grid_reduce =
888
+ GetCanApplyGridReduce (op_compute_bodies, group_info->reduce_axis );
889
+
795
890
VLOG (4 ) << group_info->DebugPrint ();
796
891
return group_info;
797
892
}
0 commit comments