Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#55 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
refine signature of group_pattern_util.ClusterIntoGroupPatternsFromOpList
  • Loading branch information
tc20042008 committed Mar 11, 2024
2 parents 5ff4943 + d3d6926 commit c0dd054
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
30 changes: 28 additions & 2 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,12 +852,38 @@ GroupPattern FuseToGroupPattern(const std::vector<pir::Operation*>& ops) {
return stmt_patterns;
}

class ClusteringHelper {
public:
ClusteringHelper(
const pir::ShapeConstraintIRAnalysis* shape_analysis,
const std::vector<pir::Operation*>& ops,
const OpsClusteringSpec& clustering_spec)
: shape_analysis_(shape_analysis), ops_(ops), clustering_spec_(clustering_spec) {
this->IsInThisOpList = MakePredicatorIsInThisFusionOp(ops);
this->IsInjectiveSource =
MakePredicatorIsInjectiveSource(ops_, this->IsInThisOpList);
}

std::vector<ConditionalGroupPattern> ClusterIntoGroupPatterns() {
LOG(FATAL) << "TODO(tianchao)";
}

private:
const pir::ShapeConstraintIRAnalysis* shape_analysis_;
const std::vector<pir::Operation*> ops_;
const OpsClusteringSpec clustering_spec_;
std::function<bool(const pir::Operation*)> IsInThisOpList;
std::function<bool(const pir::Operation*)> IsInjectiveSource;
};

} // namespace

std::vector<ConditionalGroupPattern> ClusterIntoGroupPatternsFromOpList(
const pir::ShapeConstraintIRAnalysis* shape_analysis,
const std::vector<pir::Operation*>& ops,
const OpsClusteringSpec& clusteringSpec) {
// TODO();
const OpsClusteringSpec& clustering_spec) {
ClusteringHelper helper(shape_analysis, ops, clustering_spec);
return helper.ClusterIntoGroupPatterns();
}

GroupPattern GenerateGroupPatternFromOpList(
Expand Down
3 changes: 2 additions & 1 deletion paddle/cinn/frontend/group_pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ struct OpsClusteringSpec {
};

std::vector<ConditionalGroupPattern> ClusterIntoGroupPatternsFromOpList(
const pir::ShapeConstraintIRAnalysis* shape_analysis,
const std::vector<pir::Operation*>& ops,
const OpsClusteringSpec& clusteringSpec);
const OpsClusteringSpec& clustering_spec);

GroupPattern GenerateGroupPatternFromOpList(
const std::vector<pir::Operation*>& ops);
Expand Down

0 comments on commit c0dd054

Please sign in to comment.