1414
1515#include " paddle/cinn/hlir/framework/pir/fusion_info.h"
1616#include " paddle/common/enforce.h"
17+ #include " paddle/common/flags.h"
1718#include " paddle/pir/include/core/ir_printer.h"
19+ PD_DECLARE_bool (enable_cinn_compile_cache);
1820
1921namespace cinn ::hlir::framework::pir {
2022
@@ -46,10 +48,12 @@ std::ostream& operator<<(std::ostream& os, const ValueInfo& value_info) {
4648
4749OperationInfo::OperationInfo (const ::pir::Operation& op) {
4850 name_ = op.name ();
51+ input_infos_.reserve (op.num_operands ());
4952 for (const auto value : op.operands_source ()) {
5053 if (!value || !value.type ()) continue ;
5154 input_infos_.emplace_back (value);
5255 }
56+ output_infos_.reserve (op.num_results ());
5357 for (const auto value : op.results ()) {
5458 if (!value || !value.type ()) continue ;
5559 output_infos_.emplace_back (value);
@@ -58,6 +62,7 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
5862 const auto & attributes = op.attributes ();
5963 std::map<std::string, ::pir::Attribute, std::less<>> order_attributes (
6064 attributes.begin (), attributes.end ());
65+ attr_infos_.reserve (attributes.size ());
6166 for (const auto & [attr_name, attr_value] : order_attributes) {
6267 if (!attr_value || attr_name == kOpCallStack ) continue ;
6368 attr_infos_.emplace_back (attr_name, attr_value);
@@ -85,9 +90,53 @@ std::ostream& operator<<(std::ostream& os, const OperationInfo& op_info) {
8590 return os;
8691}
8792
93+ std::size_t FusionOpInfo::hash () const {
94+ std::size_t seed = op_info_.hash ();
95+ for (const auto & [value_index, op_info_hash] : inner_deps_) {
96+ hash_combine (seed, value_index);
97+ hash_combine (seed, op_info_hash);
98+ }
99+ return seed;
100+ }
101+
102+ std::ostream& operator <<(std::ostream& os, const FusionOpInfo& info) {
103+ os << info.op_info_ << " , inner_deps:{" ;
104+ for (const auto & [value_index, op_info_hash] : info.inner_deps_ ) {
105+ os << " (" << value_index << " , " << op_info_hash << " )" ;
106+ }
107+ os << " }" ;
108+ return os;
109+ }
110+
88111FusionInfo::FusionInfo (const OpLoweringGroup& group) {
89- for (const auto * op : TopologySort (group)) {
90- op_infos_.emplace_back (*op);
112+ std::unordered_map<const ::pir::Operation*, size_t > op_mapper;
113+ unique_fn_name_ = group.FuncName ();
114+
115+ const auto GetInnerUpstreamOps =
116+ [&](const ::pir::Operation* op) -> decltype (auto ) {
117+ std::unordered_map<size_t , size_t > upstream_ops_index_hash;
118+ for (size_t i = 0 ; i < op->num_operands (); ++i) {
119+ const auto value = op->operand_source (i);
120+ if (!value || !value.defining_op ()) continue ;
121+ const auto * defining_op = value.defining_op ();
122+ if (op_mapper.count (defining_op) == 0 ) continue ;
123+ PADDLE_ENFORCE_LT (op_mapper[defining_op],
124+ this ->op_infos_ .size (),
125+ ::common::errors::OutOfRange (
126+ " Required op_mapper[defining_op] < "
127+ " op_infos_.size(), but received index %d" ,
128+ op_mapper[defining_op]));
129+ upstream_ops_index_hash.emplace (
130+ i, this ->op_infos_ [op_mapper[defining_op]].hash ());
131+ }
132+ return upstream_ops_index_hash;
133+ };
134+
135+ const auto sorted_ops = TopologySort (group);
136+ for (size_t i = 0 ; i < sorted_ops.size (); ++i) {
137+ const auto & op = sorted_ops[i];
138+ op_infos_.emplace_back (*op, GetInnerUpstreamOps (op));
139+ op_mapper.insert ({op, i});
91140 }
92141}
93142
@@ -97,13 +146,16 @@ std::size_t FusionInfo::hash() const {
97146 }
98147 std::size_t seed = 2153 ;
99148 for (const auto & info : op_infos_) hash_combine (seed, info);
149+ if (!FLAGS_enable_cinn_compile_cache) hash_combine (seed, unique_fn_name_);
100150 return seed;
101151}
102152
103153std::ostream& operator <<(std::ostream& os, const FusionInfo& fusion_info) {
104154 os << " FusionInfo - " << fusion_info.hash ();
105155 if (VLOG_IS_ON (5 )) {
106156 os << " {\n " ;
157+ if (!FLAGS_enable_cinn_compile_cache)
158+ os << " fn_name: " << fusion_info.unique_fn_name_ ;
107159 for (const auto & op_info : fusion_info.op_infos_ ) os << op_info << " \n " ;
108160 os << " }\n " ;
109161 }
0 commit comments