From 7640fe72df3080d21e8446e66650d4c1a5b97a0e Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Mon, 11 Mar 2024 13:13:33 +0000 Subject: [PATCH 1/2] declare group_pattern_util.ClusteringHelper --- paddle/cinn/frontend/group_pattern_util.cc | 26 ++++++++++++++++++++-- paddle/cinn/frontend/group_pattern_util.h | 2 +- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index 3d8890f2f6680..706ce56645fe6 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -852,12 +852,34 @@ GroupPattern FuseToGroupPattern(const std::vector& ops) { return stmt_patterns; } +class ClusteringHelper { + public: + ClusteringHelper( + const std::vector& ops, + const OpsClusteringSpec& clustering_spec) + : ops_(ops), clustering_spec_(clustering_spec){ + this->IsInThisOpList = MakePredicatorIsInThisFusionOp(ops); + this->IsInjectiveSource = + MakePredicatorIsInjectiveSource(ops_, this->IsInThisOpList); + } + + std::vector ClusterIntoGroupPatterns() { + LOG(FATAL) << "TODO(tianchao)"; + } + + private: + const std::vector ops_; + const OpsClusteringSpec clustering_spec_; + std::function IsInThisOpList; + std::function IsInjectiveSource; +}; + } // namespace std::vector ClusterIntoGroupPatternsFromOpList( const std::vector& ops, - const OpsClusteringSpec& clusteringSpec) { - // TODO(); + const OpsClusteringSpec& clustering_spec) { + return ClusteringHelper(ops, clustering_spec).ClusterIntoGroupPatterns(); } GroupPattern GenerateGroupPatternFromOpList( diff --git a/paddle/cinn/frontend/group_pattern_util.h b/paddle/cinn/frontend/group_pattern_util.h index b9785d34e99a2..5569aa644ba2d 100644 --- a/paddle/cinn/frontend/group_pattern_util.h +++ b/paddle/cinn/frontend/group_pattern_util.h @@ -29,7 +29,7 @@ struct OpsClusteringSpec { std::vector ClusterIntoGroupPatternsFromOpList( const std::vector& ops, - const OpsClusteringSpec& clusteringSpec); + const OpsClusteringSpec& clustering_spec); GroupPattern GenerateGroupPatternFromOpList( const std::vector& ops); From d3d6926eb353b063c0c8cbfc1d751d062457e0fb Mon Sep 17 00:00:00 2001 From: jiahy0825 Date: Mon, 11 Mar 2024 13:20:40 +0000 Subject: [PATCH 2/2] refine signature of group_pattern_util.ClusterIntoGroupPatternsFromOpList --- paddle/cinn/frontend/group_pattern_util.cc | 8 ++++++-- paddle/cinn/frontend/group_pattern_util.h | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc index 706ce56645fe6..62927ef4c82bb 100644 --- a/paddle/cinn/frontend/group_pattern_util.cc +++ b/paddle/cinn/frontend/group_pattern_util.cc @@ -855,9 +855,10 @@ GroupPattern FuseToGroupPattern(const std::vector& ops) { class ClusteringHelper { public: ClusteringHelper( + const pir::ShapeConstraintIRAnalysis* shape_analysis, const std::vector& ops, const OpsClusteringSpec& clustering_spec) - : ops_(ops), clustering_spec_(clustering_spec){ + : shape_analysis_(shape_analysis), ops_(ops), clustering_spec_(clustering_spec) { this->IsInThisOpList = MakePredicatorIsInThisFusionOp(ops); this->IsInjectiveSource = MakePredicatorIsInjectiveSource(ops_, this->IsInThisOpList); @@ -868,6 +869,7 @@ class ClusteringHelper { } private: + const pir::ShapeConstraintIRAnalysis* shape_analysis_; const std::vector ops_; const OpsClusteringSpec clustering_spec_; std::function IsInThisOpList; @@ -877,9 +879,11 @@ class ClusteringHelper { } // namespace std::vector ClusterIntoGroupPatternsFromOpList( + const pir::ShapeConstraintIRAnalysis* shape_analysis, const std::vector& ops, const OpsClusteringSpec& clustering_spec) { - return ClusteringHelper(ops, clustering_spec).ClusterIntoGroupPatterns(); + ClusteringHelper helper(shape_analysis, ops, clustering_spec); + return helper.ClusterIntoGroupPatterns(); } GroupPattern GenerateGroupPatternFromOpList( diff --git a/paddle/cinn/frontend/group_pattern_util.h b/paddle/cinn/frontend/group_pattern_util.h index 5569aa644ba2d..d8183f5e80232 100644 --- a/paddle/cinn/frontend/group_pattern_util.h +++ b/paddle/cinn/frontend/group_pattern_util.h @@ -28,6 +28,7 @@ struct OpsClusteringSpec { }; std::vector ClusterIntoGroupPatternsFromOpList( + const pir::ShapeConstraintIRAnalysis* shape_analysis, const std::vector& ops, const OpsClusteringSpec& clustering_spec);