
Commit 8d530db

[PIR+CINN]Add FusionOpInfo to enhance CompilationCache logic (#63615)
* [PIR+CINN]Add FusionOpInfo to enhance CompilationCache logic
* fix UT

1 parent 5152fce · commit 8d530db

8 files changed: +222 −9 lines

paddle/cinn/hlir/framework/pir/compilation_cache.cc

Lines changed: 10 additions & 4 deletions
@@ -24,7 +24,7 @@ void* BackendResource::GetHostFuncPtr() const {
   VLOG(4) << "Lookup kernel name: " << host_fn_name_;
   void* ptr = backend_compiler_->Lookup(host_fn_name_);
   PADDLE_ENFORCE_NOT_NULL(ptr,
-                          phi::errors::InvalidArgument(
+                          ::common::errors::InvalidArgument(
                               "Can't find kernel function %s", host_fn_name_));
   return ptr;
 }
@@ -34,8 +34,8 @@ void* BackendResource::GetInferFuncPtr() const {
   void* ptr = backend_compiler_->Lookup(infer_fn_name_);
   PADDLE_ENFORCE_NOT_NULL(
       ptr,
-      phi::errors::InvalidArgument("Can't find infer shape function %s",
-                                   infer_fn_name_));
+      ::common::errors::InvalidArgument("Can't find infer shape function %s",
+                                        infer_fn_name_));
   return ptr;
 }

@@ -61,7 +61,7 @@ const CompilationCache::CacheValue& CompilationCache::Get(
   PADDLE_ENFORCE_EQ(
       Has(key),
       true,
-      phi::errors::NotFound("%s is not in CompilationCache.", key));
+      ::common::errors::NotFound("%s is not in CompilationCache.", key));
   return cache_.at(key);
 }

@@ -71,6 +71,12 @@ pir::CINNKernelInfo CompilationCache::GetKernelInfo(const CacheKey& key) const {

 void CompilationCache::Insert(const CacheKey& key, const CacheValue& value) {
   VLOG(6) << "Insert CompilationCache for: " << key;
+  PADDLE_ENFORCE_EQ(Has(key),
+                    false,
+                    ::common::errors::PreconditionNotMet(
+                        "%s is already in CompilationCache while calling "
+                        "CompilationCache::Insert().",
+                        key));
   cache_.insert({key, value});
 }

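Taken together, the changes above define a simple contract for the cache: Insert() now refuses a key that is already present, Get() still requires the key to exist, and the new Size() reports how many entries are cached. Below is a minimal standalone sketch of that contract, with placeholder key/value types rather than Paddle's real FusionInfo/CompilationResult, just to make the invariants explicit:

// Minimal sketch of the cache contract enforced above: duplicate Insert is a
// precondition violation, Get requires presence, Size counts cached entries.
// CacheKey/CacheValue here are simplified placeholders, not Paddle's types.
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>

struct SketchCache {
  using CacheKey = std::string;             // stands in for FusionInfo
  using CacheValue = std::shared_ptr<int>;  // stands in for CompilationResult

  bool Has(const CacheKey& key) const { return cache_.count(key) != 0; }

  void Insert(const CacheKey& key, const CacheValue& value) {
    // mirrors PADDLE_ENFORCE_EQ(Has(key), false, ...)
    assert(!Has(key) && "key is already in the cache");
    cache_.insert({key, value});
  }

  const CacheValue& Get(const CacheKey& key) const {
    // mirrors PADDLE_ENFORCE_EQ(Has(key), true, ...)
    assert(Has(key) && "key is not in the cache");
    return cache_.at(key);
  }

  std::size_t Size() const { return cache_.size(); }

 private:
  std::unordered_map<CacheKey, CacheValue> cache_;
};

int main() {
  SketchCache cache;
  cache.Insert("fusion_a", std::make_shared<int>(1));
  assert(cache.Size() == 1);
  assert(*cache.Get("fusion_a") == 1);
  // A second cache.Insert("fusion_a", ...) would trip the duplicate check.
  return 0;
}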

paddle/cinn/hlir/framework/pir/compilation_cache.h

Lines changed: 1 addition & 0 deletions
@@ -93,6 +93,7 @@ class CompilationCache {
   const CacheValue& Get(const CacheKey& key) const;
   void Insert(const CacheKey& key, const CacheValue& value);
   void Clear();
+  size_t Size() const { return cache_.size(); }

   pir::CINNKernelInfo GetKernelInfo(const CacheKey& key) const;

paddle/cinn/hlir/framework/pir/compilation_task.cc

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@ std::shared_ptr<pir::CompilationResult> CompilationTask::BuildPirCINNKernelInfo(
   VLOG(5) << "Start to compile module into cuda kernel...";
   backend_resource->GetBackendCompiler()->Build(module, "");
   compilation_result->SetBackendResource(backend_resource);
+  VLOG(5) << "End to compile module into cuda kernel.";
   return compilation_result;
 }

paddle/cinn/hlir/framework/pir/fusion_info.cc

Lines changed: 54 additions & 2 deletions
@@ -14,7 +14,9 @@

 #include "paddle/cinn/hlir/framework/pir/fusion_info.h"
 #include "paddle/common/enforce.h"
+#include "paddle/common/flags.h"
 #include "paddle/pir/include/core/ir_printer.h"
+PD_DECLARE_bool(enable_cinn_compile_cache);

 namespace cinn::hlir::framework::pir {

@@ -46,10 +48,12 @@ std::ostream& operator<<(std::ostream& os, const ValueInfo& value_info) {

 OperationInfo::OperationInfo(const ::pir::Operation& op) {
   name_ = op.name();
+  input_infos_.reserve(op.num_operands());
   for (const auto value : op.operands_source()) {
     if (!value || !value.type()) continue;
     input_infos_.emplace_back(value);
   }
+  output_infos_.reserve(op.num_results());
   for (const auto value : op.results()) {
     if (!value || !value.type()) continue;
     output_infos_.emplace_back(value);
@@ -58,6 +62,7 @@ OperationInfo::OperationInfo(const ::pir::Operation& op) {
   const auto& attributes = op.attributes();
   std::map<std::string, ::pir::Attribute, std::less<>> order_attributes(
       attributes.begin(), attributes.end());
+  attr_infos_.reserve(attributes.size());
   for (const auto& [attr_name, attr_value] : order_attributes) {
     if (!attr_value || attr_name == kOpCallStack) continue;
     attr_infos_.emplace_back(attr_name, attr_value);
@@ -85,9 +90,53 @@ std::ostream& operator<<(std::ostream& os, const OperationInfo& op_info) {
   return os;
 }

+std::size_t FusionOpInfo::hash() const {
+  std::size_t seed = op_info_.hash();
+  for (const auto& [value_index, op_info_hash] : inner_deps_) {
+    hash_combine(seed, value_index);
+    hash_combine(seed, op_info_hash);
+  }
+  return seed;
+}
+
+std::ostream& operator<<(std::ostream& os, const FusionOpInfo& info) {
+  os << info.op_info_ << ", inner_deps:{";
+  for (const auto& [value_index, op_info_hash] : info.inner_deps_) {
+    os << " (" << value_index << ", " << op_info_hash << ")";
+  }
+  os << "}";
+  return os;
+}
+
 FusionInfo::FusionInfo(const OpLoweringGroup& group) {
-  for (const auto* op : TopologySort(group)) {
-    op_infos_.emplace_back(*op);
+  std::unordered_map<const ::pir::Operation*, size_t> op_mapper;
+  unique_fn_name_ = group.FuncName();
+
+  const auto GetInnerUpstreamOps =
+      [&](const ::pir::Operation* op) -> decltype(auto) {
+    std::unordered_map<size_t, size_t> upstream_ops_index_hash;
+    for (size_t i = 0; i < op->num_operands(); ++i) {
+      const auto value = op->operand_source(i);
+      if (!value || !value.defining_op()) continue;
+      const auto* defining_op = value.defining_op();
+      if (op_mapper.count(defining_op) == 0) continue;
+      PADDLE_ENFORCE_LT(op_mapper[defining_op],
+                        this->op_infos_.size(),
+                        ::common::errors::OutOfRange(
+                            "Required op_mapper[defining_op] < "
+                            "op_infos_.size(), but received index %d",
+                            op_mapper[defining_op]));
+      upstream_ops_index_hash.emplace(
+          i, this->op_infos_[op_mapper[defining_op]].hash());
+    }
+    return upstream_ops_index_hash;
+  };
+
+  const auto sorted_ops = TopologySort(group);
+  for (size_t i = 0; i < sorted_ops.size(); ++i) {
+    const auto& op = sorted_ops[i];
+    op_infos_.emplace_back(*op, GetInnerUpstreamOps(op));
+    op_mapper.insert({op, i});
   }
 }

@@ -97,13 +146,16 @@ std::size_t FusionInfo::hash() const {
   }
   std::size_t seed = 2153;
   for (const auto& info : op_infos_) hash_combine(seed, info);
+  if (!FLAGS_enable_cinn_compile_cache) hash_combine(seed, unique_fn_name_);
   return seed;
 }

 std::ostream& operator<<(std::ostream& os, const FusionInfo& fusion_info) {
   os << "FusionInfo - " << fusion_info.hash();
   if (VLOG_IS_ON(5)) {
     os << "{\n";
+    if (!FLAGS_enable_cinn_compile_cache)
+      os << "fn_name: " << fusion_info.unique_fn_name_;
     for (const auto& op_info : fusion_info.op_infos_) os << op_info << "\n";
     os << "}\n";
   }

paddle/cinn/hlir/framework/pir/fusion_info.h

Lines changed: 22 additions & 1 deletion
@@ -57,6 +57,21 @@ class OperationInfo {
   std::vector<AttributeInfo> attr_infos_;
 };

+class FusionOpInfo {
+ public:
+  FusionOpInfo(const ::pir::Operation &op,
+               const std::unordered_map<size_t, size_t> &deps)
+      : op_info_(op), inner_deps_(deps) {}
+
+  std::size_t hash() const;
+  friend std::ostream &operator<<(std::ostream &os, const FusionOpInfo &info);
+
+ private:
+  OperationInfo op_info_;
+  // operand_source index -> OperationInfo hash
+  std::unordered_map<size_t, size_t> inner_deps_;
+};
+
 class FusionInfo {
   using IntArgsMap = std::map<int, CINNKernelInfo::ArgDimIdx>;

@@ -74,13 +89,18 @@ class FusionInfo {
   friend std::ostream &operator<<(std::ostream &os, const FusionInfo &info);

  private:
-  std::vector<OperationInfo> op_infos_;
+  std::vector<FusionOpInfo> op_infos_;
   std::size_t cached_hash_value_{0};
+
+  // Used to give otherwise identical subgraphs a unique FusionInfo when
+  // FLAGS_enable_cinn_compile_cache = false; empty by default.
+  std::string unique_fn_name_{""};
 };

 std::ostream &operator<<(std::ostream &os, const AttributeInfo &info);
 std::ostream &operator<<(std::ostream &os, const ValueInfo &info);
 std::ostream &operator<<(std::ostream &os, const OperationInfo &info);
+std::ostream &operator<<(std::ostream &os, const FusionOpInfo &info);
 std::ostream &operator<<(std::ostream &os, const FusionInfo &info);

 // See boost.hash_combine for details
@@ -114,5 +134,6 @@ namespace std {
 REGISTER_STD_HASH(AttributeInfo);
 REGISTER_STD_HASH(ValueInfo);
 REGISTER_STD_HASH(OperationInfo);
+REGISTER_STD_HASH(FusionOpInfo);
 REGISTER_STD_HASH(FusionInfo)
 } // namespace std
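The REGISTER_STD_HASH macro itself is not part of this diff, so the following is only an assumed illustration of the pattern it presumably generates: a std::hash specialization that forwards to the class's own hash() method, which is what lets FusionInfo (and now FusionOpInfo) be used as unordered-container keys such as the CompilationCache key.

// Hypothetical sketch of the REGISTER_STD_HASH pattern (assumption, the
// macro's real body is not shown here), applied to a toy class.
#include <cstddef>
#include <string>
#include <unordered_map>

namespace demo {
class Info {
 public:
  explicit Info(std::size_t v) : value_(v) {}
  std::size_t hash() const { return value_; }
  bool operator==(const Info& other) const { return value_ == other.value_; }

 private:
  std::size_t value_;
};
}  // namespace demo

// What REGISTER_STD_HASH(Info) would presumably expand to:
namespace std {
template <>
struct hash<demo::Info> {
  std::size_t operator()(const demo::Info& info) const { return info.hash(); }
};
}  // namespace std

int main() {
  // With the std::hash specialization in place, Info works as a map key,
  // just as FusionInfo serves as the CompilationCache key.
  std::unordered_map<demo::Info, std::string> cache;
  cache.insert({demo::Info(42), "compiled kernel"});
  return cache.count(demo::Info(42)) == 1 ? 0 : 1;
}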

paddle/cinn/hlir/framework/pir_compiler.cc

Lines changed: 6 additions & 2 deletions
@@ -71,6 +71,7 @@ std::vector<pir::CINNKernelInfo> PirCompiler::Build(
                          utils::SequenceDispatcher(0, task_size),
                          /*thread_num=*/thread_size);
   }
+  VLOG(5) << "Finished compiling " << task_size << " Cinn Kernel info.";
   ctx_mapper.SetFinalize(true);
   ctx_mapper.UpdateGlobalCache();
   return ctx_mapper.RecoverKernelInfos();
@@ -115,8 +116,11 @@ CompilationContextMapper::RecoverKernelInfos() {

   std::vector<pir::CINNKernelInfo> kernel_infos(fusion_infos_.size());
   for (size_t i = 0; i < fusion_infos_.size(); ++i) {
-    kernel_infos[i] =
-        CompilationCache::Instance().GetKernelInfo(fusion_infos_[i]);
+    const auto& compilation_result =
+        FLAGS_enable_cinn_compile_cache
+            ? CompilationCache::Instance().Get(fusion_infos_[i])
+            : compilation_results_[i];
+    kernel_infos[i] = compilation_result->GetKernelInfo();
   }
   return kernel_infos;
 }

paddle/fluid/pybind/pir.cc

Lines changed: 8 additions & 0 deletions
@@ -1871,6 +1871,14 @@ void BindUtils(pybind11::module *m) {
     pybind11::gil_scoped_release release;
     VLOG(4) << "clear CINN CompilationCache and free BackendResource.";
     cinn::hlir::framework::CompilationCache::Instance().Clear();
+#endif
+  });
+
+  m->def("cinn_compilation_cache_size", []() {
+#ifdef PADDLE_WITH_CINN
+    pybind11::gil_scoped_release release;
+    VLOG(4) << "query the size of CINN CompilationCache.";
+    return cinn::hlir::framework::CompilationCache::Instance().Size();
 #endif
   });
 }
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.base import core


class LayerCase(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.relu = paddle.nn.functional.relu

    def triple_full(self):
        y1 = paddle.full([4], 1)
        y2 = paddle.full([4], 0)
        y3 = paddle.full([4], 0)
        return y1, y2, y3

    def concat_case_1(self):
        y1, y2, y3 = self.triple_full()
        out = paddle.concat([y1, y2, y3])
        return self.relu(out)

    def concat_case_2(self):
        y1, y2, y3 = self.triple_full()
        out = paddle.concat([y2, y1, y3])
        return self.relu(out)

    def concat_case_3(self):
        y1, y2, y3 = self.triple_full()
        out = paddle.concat([y3, y2, y1])
        return self.relu(out)

    def forward(self, x):
        outs = []
        for fn in [self.concat_case_1, self.concat_case_2, self.concat_case_3]:
            # Run each case several times to trigger duplicate subgraphs and cache them.
            for i in range(3):
                outs.append(self.relu(fn()))
        outs.append(self.relu(x))
        return outs


class TestLayer(unittest.TestCase):
    def setUp(self):
        self.inputs = (paddle.rand(shape=[12], dtype=paddle.float32),)
        self.net = LayerCase()

    def eval(self, net, to_static, with_prim=False, with_cinn=False):
        if to_static:
            paddle.set_flags({'FLAGS_prim_all': with_prim})
            if with_cinn:
                build_strategy = paddle.static.BuildStrategy()
                build_strategy.build_cinn_pass = True
                net = paddle.jit.to_static(
                    net, build_strategy=build_strategy, full_graph=True
                )
            else:
                net = paddle.jit.to_static(net, full_graph=True)
        paddle.seed(123)
        net.eval()
        outs = net(*self.inputs)
        return outs

    def check_with_flag(self, cache_size):
        st_out = self.eval(self.net, to_static=True)
        cinn_out = self.eval(
            self.net, to_static=True, with_prim=True, with_cinn=True
        )
        for st, cinn in zip(
            paddle.utils.flatten(st_out), paddle.utils.flatten(cinn_out)
        ):
            np.testing.assert_allclose(st.numpy(), cinn.numpy(), atol=1e-6)

        # Check cache size
        np.testing.assert_equal(
            core.pir.cinn_compilation_cache_size(), cache_size
        )

    def test_ast_prim_cinn(self):
        # NOTE(Aurelius84): Deny relu to split fused subgraph.
        paddle.set_flags(
            {
                "FLAGS_deny_cinn_ops": "relu",
                "FLAGS_prim_forward_blacklist": "pd_op.relu",
            }
        )
        # The repeated subgraphs are deduplicated by the cache, leaving one
        # entry per structurally distinct concat subgraph.
        self.check_with_flag(cache_size=3)

    def test_ast_prim_cinn_disable_cache(self):
        core.pir.clear_cinn_compilation_cache()
        # NOTE(Aurelius84): Deny relu to split fused subgraph.
        paddle.set_flags(
            {
                "FLAGS_deny_cinn_ops": "relu",
                "FLAGS_prim_forward_blacklist": "pd_op.relu",
                "FLAGS_enable_cinn_compile_cache": False,
            }
        )
        # If cinn_compile_cache is disabled, each subgraph instance is
        # considered unique.
        self.check_with_flag(cache_size=9)


if __name__ == '__main__':
    unittest.main()
