Commit 6c66abe

[PIR+CINN]Consider ShapeAnalysis SymDimExprs instead of sym_expr_str attribute as Hash element (#63636)
* merge dev
* fix ostringstream
* disable compilation cache in pre_analysis
* fix typo
* fix UT
1 parent f9dff23 commit 6c66abe

File tree

12 files changed: +81 additions, -54 deletions

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc

Lines changed: 7 additions & 0 deletions
@@ -16,6 +16,9 @@
 #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
 #include "paddle/cinn/hlir/framework/pir_compiler.h"
+#include "paddle/common/flags.h"
+
+PD_DECLARE_bool(enable_cinn_compile_cache);

 namespace cinn::dialect::ir::details {
 using cinn::hlir::framework::PirCompiler;
@@ -42,6 +45,10 @@ void FusionOpAnalysis::RunImpl(pir::Operation* op) {
 }

 void FusionOpAnalysis::PreCompileGroup() {
+  // Make compilation into lazy mode while
+  // FLAGS_enable_cinn_compile_cache=false.
+  if (!FLAGS_enable_cinn_compile_cache) return;
+
   std::vector<OpLoweringGroupPtr> groups;
   for (auto& group_info : *group_infos_) {
     if (is_dy_shape_ && NeedBroadcastWithCF(group_info.second)) continue;
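
The guard above turns group pre-compilation into a no-op whenever the compilation cache is disabled; the actual build is deferred until the jit-kernel attribute is created (see the utils.cc hunk below). A minimal standalone sketch of this lazy early-return pattern, using a stand-in flag and a hypothetical Group type rather than the Paddle/CINN API:

#include <iostream>
#include <vector>

// Stand-in for the gflags boolean declared via PD_DECLARE_bool in the patch.
static bool FLAGS_enable_cinn_compile_cache = true;

struct Group {};  // hypothetical placeholder for OpLoweringGroupPtr

void PreCompileGroups(const std::vector<Group>& groups) {
  // Lazy mode: when the cache is disabled, skip eager compilation entirely;
  // each group will instead be compiled on first use.
  if (!FLAGS_enable_cinn_compile_cache) return;
  std::cout << "eagerly compiling " << groups.size() << " groups\n";
}

int main() {
  FLAGS_enable_cinn_compile_cache = false;
  PreCompileGroups({Group{}, Group{}});  // prints nothing: compilation deferred
  return 0;
}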

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc

Lines changed: 12 additions & 3 deletions
@@ -21,10 +21,12 @@
 #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
 #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h"
 #include "paddle/cinn/hlir/framework/pir/compilation_cache.h"
+#include "paddle/cinn/hlir/framework/pir/utils.h"
 #include "paddle/cinn/hlir/framework/pir_compiler.h"
 #include "paddle/cinn/runtime/flags.h"

 PD_DECLARE_bool(cinn_enable_map_expr);
+PD_DECLARE_bool(enable_cinn_compile_cache);

 namespace cinn::dialect::ir::details {

@@ -78,12 +80,19 @@ CompileGroupAsOpAttribute(const std::vector<OpLoweringGroupPtr>& group_list) {

 std::unordered_map<std::string, ::pir::Attribute> GetJitKernelAttr(
     const OpLoweringGroupPtr& group) {
-  hlir::framework::pir::FusionInfo fusion_info(*group);
-  auto kernel_info = CompilationCache::Instance().GetKernelInfo(fusion_info);
+  const auto CreateKernelInfo = [&]() -> hlir::framework::pir::CINNKernelInfo {
+    if (FLAGS_enable_cinn_compile_cache) {
+      hlir::framework::pir::FusionInfo fusion_info(*group);
+      return CompilationCache::Instance().GetKernelInfo(fusion_info);
+    } else {
+      PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget());
+      return pir_compiler.Build({group})[0];
+    }
+  };
   std::unordered_map<std::string, ::pir::Attribute> attrs{
       {cinn::dialect::JitKernelOp::kAttrName,
        cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(),
-                                                   kernel_info)}};
+                                                   CreateKernelInfo())}};
   return attrs;
 }
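
With the lambda above, GetJitKernelAttr either looks the kernel up in the global cache (keyed by FusionInfo) or compiles the group on the spot when FLAGS_enable_cinn_compile_cache is off. The sketch below reproduces that lookup-or-build shape with hypothetical stand-ins (KernelInfo, BuildDirectly, a plain string key); it is illustrative only, not the CINN types:

#include <iostream>
#include <string>
#include <unordered_map>

static bool FLAGS_enable_cinn_compile_cache = true;  // stand-in flag

// Hypothetical stand-ins for CINNKernelInfo / FusionInfo / PirCompiler.
struct KernelInfo {
  std::string name;
};

KernelInfo BuildDirectly(const std::string& group) {
  return KernelInfo{"compiled:" + group};
}

KernelInfo GetKernelInfo(const std::string& group,
                         std::unordered_map<std::string, KernelInfo>* cache) {
  // Same shape as the CreateKernelInfo lambda in the hunk above: consult the
  // cache only when the flag is on, otherwise build the group on the spot.
  const auto create = [&]() -> KernelInfo {
    if (FLAGS_enable_cinn_compile_cache) {
      auto it = cache->find(group);
      if (it != cache->end()) return it->second;
      return (*cache)[group] = BuildDirectly(group);
    }
    return BuildDirectly(group);  // lazy, uncached build
  };
  return create();
}

int main() {
  std::unordered_map<std::string, KernelInfo> cache;
  std::cout << GetKernelInfo("fusion_0", &cache).name << "\n";  // compiled:fusion_0
  return 0;
}
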

paddle/cinn/hlir/framework/pir/compilation_cache.cc

Lines changed: 0 additions & 2 deletions
@@ -15,8 +15,6 @@
 #include "paddle/cinn/hlir/framework/pir/compilation_cache.h"
 #include "paddle/cinn/hlir/framework/pir/op_lowering_group.h"

-#include "paddle/common/enforce.h"
-
 namespace cinn::hlir::framework {

 namespace pir {

paddle/cinn/hlir/framework/pir/compilation_cache.h

Lines changed: 5 additions & 1 deletion
@@ -21,6 +21,7 @@
 #include "paddle/cinn/common/target.h"
 #include "paddle/cinn/hlir/framework/pir/fusion_info.h"
 #include "paddle/cinn/hlir/framework/pir/utils.h"
+#include "paddle/common/enforce.h"

 namespace cinn::hlir::framework {

@@ -68,7 +69,10 @@ class CompilationResult final {
   }

   pir::CINNKernelInfo GetKernelInfo() {
-    // TODO(Aurelius84): add ENFORCE_NOT_NULL
+    PADDLE_ENFORCE_NOT_NULL(backend_resource_,
+                            ::common::errors::PreconditionNotMet(
+                                "Found backend_resource_ is nullptr, please "
+                                "call SetBackendResource first."));
     return backend_resource_->GenerateKernelInfo();
   }
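
The PADDLE_ENFORCE_NOT_NULL above replaces the TODO and turns a would-be null dereference into a precondition failure with an actionable message. A small sketch of the same guarded-getter idea, using a plain exception instead of the Paddle enforce macros (names here are hypothetical):

#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for the backend resource held by CompilationResult.
struct BackendResource {
  std::string GenerateKernelInfo() const { return "kernel_info"; }
};

class CompilationResultSketch {
 public:
  void SetBackendResource(std::shared_ptr<BackendResource> r) {
    backend_resource_ = std::move(r);
  }
  std::string GetKernelInfo() const {
    // Fail fast with an actionable message instead of dereferencing nullptr.
    if (!backend_resource_) {
      throw std::runtime_error(
          "backend_resource_ is nullptr; call SetBackendResource first.");
    }
    return backend_resource_->GenerateKernelInfo();
  }

 private:
  std::shared_ptr<BackendResource> backend_resource_;
};

int main() {
  CompilationResultSketch result;
  result.SetBackendResource(std::make_shared<BackendResource>());
  return result.GetKernelInfo() == "kernel_info" ? 0 : 1;
}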

paddle/cinn/hlir/framework/pir/fusion_info.cc

Lines changed: 21 additions & 23 deletions
@@ -16,11 +16,13 @@
 #include "paddle/common/enforce.h"
 #include "paddle/common/flags.h"
 #include "paddle/pir/include/core/ir_printer.h"
+#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
 PD_DECLARE_bool(enable_cinn_compile_cache);

 namespace cinn::hlir::framework::pir {

 constexpr static char* kOpCallStack = "op_callstack";
+constexpr static char* kSymShapeStr = "sym_shape_str";

 std::size_t AttributeInfo::hash() const { return attr_.hash(); }

@@ -64,7 +66,8 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
       attributes.begin(), attributes.end());
   attr_infos_.reserve(attributes.size());
   for (const auto& [attr_name, attr_value] : order_attributes) {
-    if (!attr_value || attr_name == kOpCallStack) continue;
+    if (!attr_value || attr_name == kOpCallStack || attr_name == kSymShapeStr)
+      continue;
     attr_infos_.emplace_back(attr_name, attr_value);
   }
 }
@@ -138,6 +141,16 @@ FusionInfo::FusionInfo(const OpLoweringGroup& group) {
     op_infos_.emplace_back(*op, GetInnerUpstreamOps(op));
     op_mapper.insert({op, i});
   }
+  auto& shape_analysis =
+      ::pir::ShapeAnalysisManager::Instance().Get(group.GetParentProgram());
+  for (const auto& value : group.GetInputOpValues()) {
+    if (!shape_analysis.HasShapeOrDataForValue(value)) {
+      VLOG(4) << "FusionInfo: input value doesn't have shape or data, skip it."
+              << value.impl();
+      continue;
+    }
+    input_dim_exprs_.push_back(shape_analysis.GetShapeOrDataForValue(value));
+  }
 }

 std::size_t FusionInfo::hash() const {
@@ -146,7 +159,9 @@ std::size_t FusionInfo::hash() const {
   }
   std::size_t seed = 2153;
   for (const auto& info : op_infos_) hash_combine(seed, info);
+  for (const auto& dim_expr : input_dim_exprs_) hash_combine(seed, dim_expr);
   if (!FLAGS_enable_cinn_compile_cache) hash_combine(seed, unique_fn_name_);
+
   return seed;
 }

@@ -155,34 +170,17 @@ std::ostream& operator<<(std::ostream& os, const FusionInfo& fusion_info) {
   if (VLOG_IS_ON(5)) {
     os << "{\n";
     if (!FLAGS_enable_cinn_compile_cache)
-      os << "fn_name: " << fusion_info.unique_fn_name_;
+      os << "fn_name: " << fusion_info.unique_fn_name_ << ", ";
+    os << "input_dim_exprs: {";
+    for (const auto& dim_expr : fusion_info.input_dim_exprs_)
+      os << " " << dim_expr;
+    os << " }\n";
     for (const auto& op_info : fusion_info.op_infos_) os << op_info << "\n";
     os << "}\n";
   }
   return os;
 }

-std::size_t HashIntArgsMap(
-    const std::map<int, CINNKernelInfo::ArgDimIdx>& int_args_map) {
-  std::size_t seed = 2153;
-  for (const auto& [input_idx, dim_idx] : int_args_map) {
-    hash_combine(seed, input_idx);
-    hash_combine(seed, dim_idx.arg_idx);
-    hash_combine(seed, dim_idx.dim_idx);
-  }
-  return seed;
-}
-std::ostream& operator<<(
-    std::ostream& os,
-    const std::map<int, CINNKernelInfo::ArgDimIdx>& int_args_map) {
-  os << "int_args_map: {\n";
-  for (const auto& [input_idx, dim_idx] : int_args_map) {
-    os << "input_idx: " << input_idx << ":[ " << dim_idx.arg_idx << ", "
-       << dim_idx.dim_idx << " ]\n";
-  }
-  os << "}\n";
-}
-
 std::vector<const ::pir::Operation*> TopologySort(
     const OpLoweringGroup& group) {
   // NOTE(Aurelius84): Use simplest one-by-one order temporaly.
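
FusionInfo now folds the symbolic shapes of the group's input values into its hash (via the ShapeAnalysis lookup above) instead of relying on the serialized sym_shape_str attribute. The combination step is the boost-style hash_combine declared in fusion_info.h; a minimal, self-contained illustration with strings standing in for symbol::ShapeOrDataDimExprs:

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Same mixing formula as hash_combine in fusion_info.h.
template <typename T>
void hash_combine(std::size_t& seed, const T& v) {
  std::hash<T> hasher;
  seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

int main() {
  // Strings stand in for symbol::ShapeOrDataDimExprs, e.g. "shape[S0, 256]".
  std::vector<std::string> input_dim_exprs = {"shape[S0, 256]", "shape[S0, 1]"};
  std::size_t seed = 2153;  // same starting seed as FusionInfo::hash()
  for (const auto& dim_expr : input_dim_exprs) hash_combine(seed, dim_expr);
  std::cout << "fusion hash seed: " << seed << "\n";
  return 0;
}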

paddle/cinn/hlir/framework/pir/fusion_info.h

Lines changed: 2 additions & 5 deletions
@@ -15,6 +15,7 @@
 #pragma once
 #include <ostream>
 #include "paddle/cinn/hlir/framework/pir/op_lowering_group.h"
+#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"

 namespace cinn::hlir::framework::pir {

@@ -90,6 +91,7 @@ class FusionInfo {

 private:
  std::vector<FusionOpInfo> op_infos_;
+ std::vector<::symbol::ShapeOrDataDimExprs> input_dim_exprs_;
  std::size_t cached_hash_value_{0};

  // Used to make same subgraphs have unique FusionInfo while
@@ -111,11 +113,6 @@ inline void hash_combine(std::size_t &seed,  // NOLINT
   seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
 }

-std::size_t HashIntArgsMap(
-    const std::map<int, CINNKernelInfo::ArgDimIdx> &int_args_map);
-std::ostream &operator<<(
-    std::ostream &os,
-    const std::map<int, CINNKernelInfo::ArgDimIdx> &int_args_map);
 std::vector<const ::pir::Operation *> TopologySort(
     const OpLoweringGroup &group);

paddle/cinn/hlir/framework/pir/op_lowering_group.cc

Lines changed: 7 additions & 4 deletions
@@ -74,18 +74,21 @@ std::vector<::pir::Value> OpLoweringGroup::GetGroupOutputValues() const {
   return output_values;
 }

-std::unordered_set<::pir::Value> OpLoweringGroup::GetInputOpValues() const {
-  std::unordered_set<::pir::Value> group_inputs;
+std::vector<::pir::Value> OpLoweringGroup::GetInputOpValues() const {
+  std::unordered_set<::pir::Value> visited_values;
+  std::vector<::pir::Value> group_inputs;
   std::unordered_set<::pir::Operation*> ops_set(this->ops_.begin(),
                                                 this->ops_.end());

   // count all op's input Value
-  for (auto op : ops_set) {
+  for (auto op : ops_) {
     for (auto& value : op->operands_source()) {
       if (!value || !value.type() || ops_set.count(value.defining_op()))
         continue;
+      if (visited_values.count(value)) continue;
       // if the input value owner op is not in OpSet, it's the group's input
-      group_inputs.insert(value);
+      visited_values.insert(value);
+      group_inputs.push_back(value);
     }
   }
   return group_inputs;
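
GetInputOpValues now returns a vector populated in ops_ iteration order, with an unordered_set used only for deduplication; a deterministic order matters because these values feed the FusionInfo hash. A small generic sketch of the same order-preserving dedup, not tied to pir::Value:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Order-preserving dedup: the set only answers "seen before?",
// while the vector keeps the first-seen order.
std::vector<std::string> UniqueInOrder(const std::vector<std::string>& values) {
  std::unordered_set<std::string> visited;
  std::vector<std::string> result;
  for (const auto& v : values) {
    if (visited.count(v)) continue;
    visited.insert(v);
    result.push_back(v);
  }
  return result;
}

int main() {
  for (const auto& v : UniqueInOrder({"a", "b", "a", "c", "b"}))
    std::cout << v << " ";  // prints: a b c
  return 0;
}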

paddle/cinn/hlir/framework/pir/op_lowering_group.h

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ class OpLoweringGroup {
   ::pir::Block* GetParentBlock() const;
   ::pir::Program* GetParentProgram() const;
   std::vector<::pir::Value> GetGroupOutputValues() const;
-  std::unordered_set<::pir::Value> GetInputOpValues() const;
+  std::vector<::pir::Value> GetInputOpValues() const;
   std::unordered_set<::pir::Value> GetOutputOpValues() const;
   const symbol::ShapeOrDataDimExprs& GetShapeOrDataExprs(
       const ::pir::Value& value) const;

paddle/cinn/hlir/framework/pir_compiler.cc

Lines changed: 1 addition & 3 deletions
@@ -137,10 +137,8 @@ void CompilationContextMapper::UpdateGlobalCache() {
                       ::common::errors::PreconditionNotMet(
                           "Required mapper_index < fusion_infos_.size()."));
     const auto& fusion_info = fusion_infos_[mapper_index_[i]];
-    const auto& int_args_map =
-        compilation_results_[i]->GetBackendResource()->GetIntArgsMap();
     VLOG(5) << "Insert new compiled result into cache, fusion_info: "
-            << fusion_info << ", int_args_map: " << int_args_map;
+            << fusion_info;
     CompilationCache::Instance().Insert(fusion_info, compilation_results_[i]);
   }
 }

paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h

Lines changed: 12 additions & 1 deletion
@@ -13,7 +13,7 @@
 // limitations under the License.

 #pragma once
-
+#include <sstream>
 #include "paddle/pir/include/dialect/shape/utils/dim_expr.h"
 #include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

@@ -172,4 +172,15 @@ IR_API ShapeOrDataDimExprs SubstituteShapeOrData(

 IR_API std::ostream& operator<<(std::ostream&,
                                 const ShapeOrDataDimExprs& dim_expr);
+
 }  // namespace symbol
+namespace std {
+template <>
+struct hash<symbol::ShapeOrDataDimExprs> {
+  std::size_t operator()(const symbol::ShapeOrDataDimExprs& obj) const {
+    std::ostringstream os;
+    os << obj;
+    return std::hash<std::string>()(os.str());
+  }
+};
+}  // namespace std
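
The std::hash specialization above hashes a ShapeOrDataDimExprs by streaming it through operator<< and hashing the resulting string, which is what lets input_dim_exprs_ participate in hash_combine. A standalone sketch of the same technique for a hypothetical printable type (the trade-off is one string formatting per hash call, and two values collide exactly when their printed forms match):

#include <functional>
#include <iostream>
#include <sstream>
#include <string>

// Hypothetical printable type standing in for symbol::ShapeOrDataDimExprs.
struct SymShape {
  std::string repr;
};

std::ostream& operator<<(std::ostream& os, const SymShape& s) {
  return os << s.repr;
}

namespace std {
template <>
struct hash<SymShape> {
  std::size_t operator()(const SymShape& obj) const {
    // Hash the streamed textual form, as the patch does for ShapeOrDataDimExprs.
    std::ostringstream os;
    os << obj;
    return std::hash<std::string>()(os.str());
  }
};
}  // namespace std

int main() {
  SymShape s{"shape[S0, 256]"};
  std::cout << std::hash<SymShape>()(s) << "\n";
  return 0;
}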
