1616#include " paddle/common/enforce.h"
1717#include " paddle/common/flags.h"
1818#include " paddle/pir/include/core/ir_printer.h"
19+ #include " paddle/pir/include/dialect/shape/utils/shape_analysis.h"
1920PD_DECLARE_bool (enable_cinn_compile_cache);
2021
2122namespace cinn ::hlir::framework::pir {
2223
// Attribute names that are deliberately excluded when hashing an operation's
// attributes: neither the recorded Python call stack nor the symbolic-shape
// annotation affects the generated kernel, so including them would needlessly
// fragment the compile cache.
// NOTE: string literals are const; binding them to a plain `char*` has been
// ill-formed since C++11, hence `const char*` here.
constexpr static const char* kOpCallStack = "op_callstack";
constexpr static const char* kSymShapeStr = "sym_shape_str";
2426
2527std::size_t AttributeInfo::hash () const { return attr_.hash (); }
2628
@@ -64,7 +66,8 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
6466 attributes.begin (), attributes.end ());
6567 attr_infos_.reserve (attributes.size ());
6668 for (const auto & [attr_name, attr_value] : order_attributes) {
67- if (!attr_value || attr_name == kOpCallStack ) continue ;
69+ if (!attr_value || attr_name == kOpCallStack || attr_name == kSymShapeStr )
70+ continue ;
6871 attr_infos_.emplace_back (attr_name, attr_value);
6972 }
7073}
@@ -138,6 +141,16 @@ FusionInfo::FusionInfo(const OpLoweringGroup& group) {
138141 op_infos_.emplace_back (*op, GetInnerUpstreamOps (op));
139142 op_mapper.insert ({op, i});
140143 }
144+ auto & shape_analysis =
145+ ::pir::ShapeAnalysisManager::Instance ().Get(group.GetParentProgram());
146+ for (const auto & value : group.GetInputOpValues ()) {
147+ if (!shape_analysis.HasShapeOrDataForValue (value)) {
148+ VLOG (4 ) << " FusionInfo: input value doesn't have shape or data, skip it."
149+ << value.impl ();
150+ continue ;
151+ }
152+ input_dim_exprs_.push_back (shape_analysis.GetShapeOrDataForValue (value));
153+ }
141154}
142155
143156std::size_t FusionInfo::hash () const {
@@ -146,7 +159,9 @@ std::size_t FusionInfo::hash() const {
146159 }
147160 std::size_t seed = 2153 ;
148161 for (const auto & info : op_infos_) hash_combine (seed, info);
162+ for (const auto & dim_expr : input_dim_exprs_) hash_combine (seed, dim_expr);
149163 if (!FLAGS_enable_cinn_compile_cache) hash_combine (seed, unique_fn_name_);
164+
150165 return seed;
151166}
152167
@@ -155,34 +170,17 @@ std::ostream& operator<<(std::ostream& os, const FusionInfo& fusion_info) {
155170 if (VLOG_IS_ON (5 )) {
156171 os << " {\n " ;
157172 if (!FLAGS_enable_cinn_compile_cache)
158- os << " fn_name: " << fusion_info.unique_fn_name_ ;
173+ os << " fn_name: " << fusion_info.unique_fn_name_ << " , " ;
174+ os << " input_dim_exprs: {" ;
175+ for (const auto & dim_expr : fusion_info.input_dim_exprs_ )
176+ os << " " << dim_expr;
177+ os << " }\n " ;
159178 for (const auto & op_info : fusion_info.op_infos_ ) os << op_info << " \n " ;
160179 os << " }\n " ;
161180 }
162181 return os;
163182}
164183
165- std::size_t HashIntArgsMap (
166- const std::map<int , CINNKernelInfo::ArgDimIdx>& int_args_map) {
167- std::size_t seed = 2153 ;
168- for (const auto & [input_idx, dim_idx] : int_args_map) {
169- hash_combine (seed, input_idx);
170- hash_combine (seed, dim_idx.arg_idx );
171- hash_combine (seed, dim_idx.dim_idx );
172- }
173- return seed;
174- }
175- std::ostream& operator <<(
176- std::ostream& os,
177- const std::map<int , CINNKernelInfo::ArgDimIdx>& int_args_map) {
178- os << " int_args_map: {\n " ;
179- for (const auto & [input_idx, dim_idx] : int_args_map) {
180- os << " input_idx: " << input_idx << " :[ " << dim_idx.arg_idx << " , "
181- << dim_idx.dim_idx << " ]\n " ;
182- }
183- os << " }\n " ;
184- }
185-
186184std::vector<const ::pir::Operation*> TopologySort (
187185 const OpLoweringGroup& group) {
188186 // NOTE(Aurelius84): Use simplest one-by-one order temporaly.
0 commit comments