Skip to content

Commit fed2ec8

Browse files
committed
[CINN] Add method to check applicability of GridReduce
1 parent 951eeee commit fed2ec8

File tree

2 files changed

+103
-4
lines changed

2 files changed

+103
-4
lines changed

paddle/cinn/hlir/framework/pir/trivial_op_impl.cc

Lines changed: 95 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -721,6 +721,97 @@ std::vector<int64_t> GetLoopStrides(const ir::Expr& body,
721721
return loop_strides;
722722
}
723723

724+
// Check whether we can apply grid reduce in this fusion group.
725+
// We can apply grid reduce if there is exactly one reduce, and whose result is
726+
// not broadcasted before output.
727+
bool GetCanApplyGridReduce(const std::vector<ir::Expr>& op_compute_bodies,
728+
const std::vector<int64_t>& reduce_axis) {
729+
using trivial_fusion_detail::GetAllForIters;
730+
using trivial_fusion_detail::IsReduceBody;
731+
using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes;
732+
using trivial_fusion_detail::ExprSetFinderUtils::ChildStores;
733+
using trivial_fusion_detail::ExprSetFinderUtils::ChildTensorLoads;
734+
using trivial_fusion_detail::ExprSetFinderUtils::
735+
ScheduleBlockRealizeIsNotInit;
736+
737+
// Names of tensors that are downstream of reduce.
738+
// A tensor is downstream of reduce either if it is produced by a reduce, or
739+
// if it has data dependency on another tensor that is downstream of reduce.
740+
std::unordered_set<std::string> reduce_downstream_tensor_names;
741+
int reduce_count = 0;
742+
743+
const auto IsReduceDownstream = [&](const ir::Expr& expr_block) {
744+
for (auto& expr_load : ChildTensorLoads(expr_block)) {
745+
std::string load_tensor_name = expr_load.As<ir::Load>()->name();
746+
if (reduce_downstream_tensor_names.count(load_tensor_name) > 0) {
747+
return true;
748+
}
749+
}
750+
return false;
751+
};
752+
753+
const auto AddReduceDownstream = [&](const ir::Expr& expr_block) {
754+
auto expr_store = ChildStores.GetSingle(expr_block);
755+
std::string store_tensor_name = expr_store.As<ir::Store>()->name();
756+
reduce_downstream_tensor_names.insert(store_tensor_name);
757+
};
758+
759+
const auto CheckOutputHasReduceAxis = [&](const ir::Expr& body,
760+
const ir::Expr& expr_block) {
761+
std::vector<ir::Var> all_loop_vars = GetAllForIters(body);
762+
std::unordered_set<std::string> reduce_loop_vars;
763+
for (int64_t axis : reduce_axis) {
764+
reduce_loop_vars.insert(all_loop_vars[axis]->name);
765+
}
766+
767+
std::unordered_set<std::string> reduce_iter_vars;
768+
auto* block = expr_block.As<ir::ScheduleBlockRealize>();
769+
auto& iter_vars = block->schedule_block.As<ir::ScheduleBlock>()->iter_vars;
770+
for (int i = 0; i < iter_vars.size(); i++) {
771+
ir::Var loop_var = block->iter_values[i].as_var_ref();
772+
if (reduce_loop_vars.count(loop_var->name) > 0) {
773+
reduce_iter_vars.insert(iter_vars[i]->name);
774+
}
775+
}
776+
777+
// The result is true if the indices of the output tensor contain any
778+
// reduce iter vars.
779+
auto expr_store = ChildStores.GetSingle(expr_block);
780+
for (auto& index_expr : expr_store.As<ir::Store>()->indices) {
781+
if (reduce_iter_vars.count(index_expr.as_var_ref()->name) > 0) {
782+
return true;
783+
}
784+
}
785+
return false;
786+
};
787+
788+
for (const auto& body : op_compute_bodies) {
789+
ir::Expr expr_block =
790+
(ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit)
791+
.GetSingle(body);
792+
bool is_reduce_body = IsReduceBody(body);
793+
bool is_reduce_downstream = IsReduceDownstream(expr_block);
794+
bool output_has_reduce_axis = CheckOutputHasReduceAxis(body, expr_block);
795+
796+
if (is_reduce_body) {
797+
++reduce_count;
798+
}
799+
if (is_reduce_downstream || is_reduce_body) {
800+
AddReduceDownstream(expr_block);
801+
}
802+
803+
// When a block is downstream of reduce, its output shouldn't contain
804+
// reduce axis. Otherwise, it broadcasts the result of reduce. If this
805+
// is the case, we cannot apply grid reduce.
806+
if (is_reduce_downstream && output_has_reduce_axis) {
807+
VLOG(4) << "grid reduce is prohibited by block: " << expr_block;
808+
return false;
809+
}
810+
}
811+
812+
return reduce_count == 1;
813+
}
814+
724815
std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
725816
const std::vector<ir::Expr>& op_compute_bodies) {
726817
using trivial_fusion_detail::AppendBound;
@@ -792,6 +883,10 @@ std::shared_ptr<FusionGroupInfo> GetFusionGroupInfo(
792883
}
793884
});
794885
}
886+
887+
group_info->can_apply_grid_reduce =
888+
GetCanApplyGridReduce(op_compute_bodies, group_info->reduce_axis);
889+
795890
VLOG(4) << group_info->DebugPrint();
796891
return group_info;
797892
}

paddle/cinn/hlir/framework/pir/trivial_op_impl.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,16 @@ struct FusionGroupInfo {
168168
std::vector<int64_t> loop_strides;
std::vector<int64_t> reduce_axis;
std::vector<std::string> reduce_var_name;
// Default to false so the flag is never read uninitialized (reading an
// indeterminate bool is UB); GetFusionGroupInfo() overwrites it with the
// result of GetCanApplyGridReduce().
bool can_apply_grid_reduce = false;
171172

172173
std::string DebugPrint() {
173-
return "GroupInfo\nloop_ranges: " + cinn::utils::Join(loop_ranges, " ") +
174-
"\nloop_strides: " + cinn::utils::Join(loop_strides, ", ") +
175-
"\nreduce_axis: " + cinn::utils::Join(reduce_axis, " ") +
176-
"\nreduce_var_name: " + cinn::utils::Join(reduce_var_name, " ");
174+
std::stringstream ss;
175+
ss << "GroupInfo\nloop_ranges: " << cinn::utils::Join(loop_ranges, " ")
176+
<< "\nloop_strides: " << cinn::utils::Join(loop_strides, ", ")
177+
<< "\nreduce_axis: " << cinn::utils::Join(reduce_axis, " ")
178+
<< "\nreduce_var_name: " << cinn::utils::Join(reduce_var_name, " ")
179+
<< "\ncan_apply_grid_reduce: " << can_apply_grid_reduce;
180+
return ss.str();
177181
}
178182
};
179183

0 commit comments

Comments
 (0)