Skip to content

Commit 76401f4

Browse files
author
Feiyu Chan
authored
Remove cache of cuFFT & Disable ONEMKL (PaddlePaddle#59)
1. Replace numpy.fft with scipy.fft, as numpy<1.20 does not support the ortho norm; 2. remove the cache of cuFFT plans; 3. enhance error checking; 4. default WITH_ONEMKL to OFF.
1 parent e968c20 commit 76401f4

File tree

75 files changed

+3597
-775
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

75 files changed

+3597
-775
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ project(paddle CXX C)
3939
# TODO(Shibo Tao): remove find_package(CUDA) completely.
4040
find_package(CUDA QUIET)
4141
find_package(MKL CONFIG QUIET)
42-
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" ${MKL_FOUND})
42+
option(WITH_ONEMKL "Compile PaddlePaddle with oneMKL" OFF)
4343
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
4444
option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
4545
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)

paddle/fluid/extension/src/ext_tensor.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
9090
PADDLE_THROW(platform::errors::Unavailable(
9191
"Only GPU related Copy can reach this func."));
9292
}
93-
cudaStreamSynchronize(dev_ctx->stream());
9493
#elif defined(PADDLE_WITH_HIP)
9594
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
9695
int device_num = paddle::platform::GetCurrentDeviceId();
@@ -110,7 +109,6 @@ void DeviceCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
110109
PADDLE_THROW(platform::errors::Unavailable(
111110
"Only GPU related Copy can reach this func."));
112111
}
113-
hipStreamSynchronize(dev_ctx->stream());
114112
#else
115113
PADDLE_THROW(platform::errors::Unavailable(
116114
"This function can only be used if compiled with"

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_subdirectory(io)
2828
add_subdirectory(new_executor)
2929
#ddim lib
3030
proto_library(framework_proto SRCS framework.proto)
31+
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)
3132

3233
proto_library(op_def_proto SRCS op_def.proto DEPS framework_proto)
3334
cc_library(op_def_api SRCS op_def_api.cc DEPS op_def_proto boost)

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ pass_library(multihead_matmul_fuse_pass inference)
9595
pass_library(adaptive_pool2d_convert_global_pass inference)
9696
pass_library(unsqueeze2_eltwise_fuse_pass inference)
9797
pass_library(layer_norm_fuse_pass inference)
98+
pass_library(generate_pass DEPS pass_desc_proto)
99+
target_link_libraries(generate_pass pass_desc_proto)
98100
if(WITH_GPU OR WITH_ROCM)
99101
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
100102
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
@@ -156,6 +158,7 @@ cc_test(test_conv_bn_fuse_pass_cc SRCS conv_bn_fuse_pass_tester.cc DEPS conv_bn_
156158
cc_test(test_adaptive_pool2d_convert_global_pass SRCS adaptive_pool2d_convert_global_pass_tester.cc DEPS adaptive_pool2d_convert_global_pass)
157159
cc_test(test_unsqueeze2_eltwise_fuse_pass SRCS unsqueeze2_eltwise_fuse_pass_tester.cc DEPS unsqueeze2_eltwise_fuse_pass)
158160
cc_test(test_layer_norm_fuse_pass_cc SRCS layer_norm_fuse_pass_tester.cc DEPS layer_norm_fuse_pass pass_test_util naive_executor)
161+
cc_test(test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto)
159162
if(WITH_GPU OR WITH_ROCM)
160163
cc_test(test_embedding_eltwise_layernorm_fuse_pass SRCS embedding_eltwise_layernorm_fuse_pass_tester.cc DEPS embedding_eltwise_layernorm_fuse_pass)
161164
cc_test(test_cudnn_placement_pass SRCS cudnn_placement_pass_tester.cc DEPS cudnn_placement_pass)
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/generate_pass.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
namespace ir {
20+
21+
// Builds a PDPattern from the `pattern` program of a PassDesc: one PDNode per
// operator (named by its index in block 0, so the rewrite phase can look it up
// by index) and one PDNode per variable argument, with edges var->op for
// inputs and op->var for outputs. A variable that is both produced and
// consumed inside the pattern is marked intermediate. Attribute conditions
// from the PassDesc are attached as assert_more predicates on the op node.
void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
22+
// Only block 0 of the pattern program is considered.
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
23+
// Traverse all operators to create subgraph.
24+
for (int index = 0; index < block.ops_size(); ++index) {
25+
const proto::OpDesc& op = block.ops(index);
26+
// Create a PDNode for current operator. Use the index as name to avoid
27+
// multiple operators with same type. Get a PDNode from pattern subgraph
28+
// through index in rewrite phase.
29+
PDNode* op_pdnode =
30+
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
31+
// Create PDNodes for inputs of current operator.
32+
for (const proto::OpDesc::Var& var : op.inputs()) {
33+
for (const std::string& argument : var.arguments()) {
34+
// The input may be the output of other operator.
35+
PDNode* var_pdnode = pattern->RetrieveNode(argument);
36+
if (nullptr == var_pdnode) {
37+
// First time this variable is seen: it enters the pattern from outside.
var_pdnode = pattern->NewNode(argument)->AsInput();
38+
} else if (var_pdnode->IsOutput()) {
39+
// Produced by an earlier op in the pattern and consumed here, so it is
// internal to the matched subgraph.
var_pdnode->AsIntermediate();
40+
}
41+
var_pdnode->assert_is_op_input(op.type());
42+
pattern->AddEdge(var_pdnode, op_pdnode);
43+
}
44+
}
45+
// Create PDNodes for outputs of current operator.
46+
for (const proto::OpDesc::Var& var : op.outputs()) {
47+
for (const std::string& argument : var.arguments()) {
48+
// The output may be the input of other operator.
49+
PDNode* var_pdnode = pattern->RetrieveNode(argument);
50+
if (nullptr == var_pdnode) {
51+
var_pdnode = pattern->NewNode(argument)->AsOutput();
52+
} else if (var_pdnode->IsInput()) {
53+
var_pdnode->AsIntermediate();
54+
}
55+
var_pdnode->assert_is_op_output(op.type());
56+
pattern->AddEdge(op_pdnode, var_pdnode);
57+
}
58+
}
59+
// Set attribute condition for current operator.
60+
// A matched op node must carry every attribute listed in the PassDesc with
// an equal value; a missing attribute rejects the match.
for (const proto::OpDesc::Attr& attr : op.attrs()) {
61+
// NOTE(review): the predicate captures `attr` by reference and runs later,
// during pattern matching — it stays valid because `attr` aliases storage
// inside `pass_desc`, which must outlive the detector (true in
// GeneratePass::ApplyImpl, where the desc is a member). Confirm for any
// other caller.
op_pdnode->assert_more([&](Node* x) {
62+
if (x && x->IsOp()) {
63+
OpDesc* op_desc = x->Op();
64+
if (op_desc->HasAttr(attr.name())) {
65+
return GetAttrValue(attr) == op_desc->GetAttr(attr.name());
66+
}
67+
return false;
68+
}
69+
return false;
70+
});
71+
}
72+
}
73+
}
74+
75+
// Returns the rewrite handler invoked by GraphPatternDetector for every match
// of `pattern`. The handler materializes the `replace` program of the
// PassDesc into the graph, reconnecting it to the matched subgraph via the
// PassDesc var_maps, then removes the matched nodes that are not kept alive
// by a var_map entry.
// NOTE(review): the lambda captures `pattern` and `pass_desc` by reference;
// the returned handler must not outlive them. In GeneratePass::ApplyImpl the
// handler is consumed immediately by `detector(...)`, so this holds — confirm
// for any other caller.
GraphPatternDetector::handle_t GetGenerateRewrite(
76+
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
77+
// NOTE(review): `subgraph` is taken by value (a whole map copy per match);
// the detector's handle_t signature presumably requires this — verify
// whether a const reference is possible.
GraphPatternDetector::handle_t handler = [&](
78+
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) {
79+
// There are some duplicate patterns.
80+
// Overlapping matches: if any node of this match was already removed by a
// previous rewrite, skip the whole match.
for (auto iter : subgraph) {
81+
if (nullptr == graph->RetrieveNode(iter.second->id())) {
82+
VLOG(3) << "Node [" << iter.second->Name()
83+
<< "] of subgraph has been removed. So skip this optimize.";
84+
return;
85+
}
86+
}
87+
// Only block 0 of the replace program is materialized.
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
88+
// `var_node_maps` record the mapping of variable to the pattern subgraph.
89+
// Keyed by replace-side variable name; values are the existing graph nodes
// matched on the pattern side.
std::map<std::string, Node*> var_node_maps;
90+
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
91+
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
92+
var_node_maps.insert({var_map.replace_var(), node});
93+
}
94+
// Traverse all operators to create subgraph.
95+
for (const proto::OpDesc& op : block.ops()) {
96+
OpDesc op_desc;
97+
std::vector<Node *> in_nodes, out_nodes;
98+
op_desc.SetType(op.type());
99+
// Create Nodes for inputs of current operator.
100+
for (const proto::OpDesc::Var& var : op.inputs()) {
101+
std::vector<std::string> arguments;
102+
for (const std::string& argument : var.arguments()) {
103+
// The input may be mapped on the operator of pattern subgraph.
104+
Node* node = nullptr;
105+
auto iter = var_node_maps.find(argument);
106+
if (var_node_maps.end() == iter) {
107+
// Variable internal to the replace program: create a fresh, uniquely
// named var node and remember it for later references.
VarDesc var_desc(patterns::UniqueKey(argument));
108+
node = graph->CreateVarNode(&var_desc);
109+
var_node_maps.insert({argument, node});
110+
} else {
111+
node = iter->second;
112+
}
113+
in_nodes.push_back(node);
114+
arguments.push_back(node->Name());
115+
}
116+
op_desc.SetInput(var.parameter(), arguments);
117+
}
118+
// Create Nodes for outputs of current operator.
119+
for (const proto::OpDesc::Var& var : op.outputs()) {
120+
std::vector<std::string> arguments;
121+
for (const std::string& argument : var.arguments()) {
122+
// The output may be mapped on the operator of pattern subgraph.
123+
Node* node = nullptr;
124+
auto iter = var_node_maps.find(argument);
125+
if (var_node_maps.end() == iter) {
126+
VarDesc var_desc(patterns::UniqueKey(argument));
127+
node = graph->CreateVarNode(&var_desc);
128+
var_node_maps.insert({argument, node});
129+
} else {
130+
node = iter->second;
131+
}
132+
out_nodes.push_back(node);
133+
arguments.push_back(node->Name());
134+
}
135+
op_desc.SetOutput(var.parameter(), arguments);
136+
}
137+
// Set attribute for current operator.
138+
for (const proto::OpDesc::Attr& attr : op.attrs()) {
139+
op_desc.SetAttr(attr.name(), GetAttrValue(attr));
140+
}
141+
// Create a Node for current operator.
142+
Node* op_node = graph->CreateOpNode(&op_desc);
143+
for (Node* node : in_nodes) {
144+
IR_NODE_LINK_TO(node, op_node);
145+
}
146+
for (Node* node : out_nodes) {
147+
IR_NODE_LINK_TO(op_node, node);
148+
}
149+
}
150+
// Remove nodes that are intermediate.
151+
// Start from every matched node, then keep the ones still referenced by the
// replace program through `var_node_maps`.
std::unordered_set<const Node*> remove_nodes;
152+
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
153+
remove_nodes.emplace(subgraph.at(pdnode.get()));
154+
}
155+
for (auto iter : var_node_maps) {
156+
remove_nodes.erase(iter.second);
157+
}
158+
GraphSafeRemoveNodes(graph, remove_nodes);
159+
};
160+
return handler;
161+
}
162+
163+
// Constructs the pass from a serialized proto::MultiPassDesc, then validates
// it via VerifyDesc().
// NOTE(review): the boolean result of ParseFromString is ignored — a corrupt
// binary string yields an empty/partial desc that only VerifyDesc's size
// check may catch. Consider enforcing the parse result.
GeneratePass::GeneratePass(const std::string& binary_str) {
164+
multi_pass_desc_.ParseFromString(binary_str);
165+
VerifyDesc();
166+
}
167+
168+
// Constructs the pass from an in-memory proto::MultiPassDesc (copied into the
// member), then validates it via VerifyDesc().
GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc)
169+
: multi_pass_desc_(multi_pass_desc) {
170+
VerifyDesc();
171+
}
172+
173+
// Applies each PassDesc in turn: build its pattern, run the detector with the
// generated rewrite handler, then verify the rewritten graph.
void GeneratePass::ApplyImpl(Graph* graph) const {
174+
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
175+
GraphPatternDetector detector;
176+
InitGeneratePattern(pass_desc, detector.mutable_pattern());
177+
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
178+
// The rewrited graph needs to be verified. Current Pass should be skipped
179+
// if validation failed. Rewrite based on the original graph cannot
180+
// implement rollback operation.
181+
// NOTE(review): the return value of VerifyGraph is discarded, so the
// skip-on-failure described above is not actually enforced yet
// (VerifyGraph is currently a stub that always returns true).
VerifyGraph(*graph);
182+
}
183+
}
184+
185+
// Validates multi_pass_desc_: it must contain at least one PassDesc, and in
// both the pattern and replace programs every operator input must either be
// listed in `var_maps` or be produced as an output by some operator of the
// same block. Violations raise InvalidArgument via PADDLE_ENFORCE_NE.
void GeneratePass::VerifyDesc() const {
186+
PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0,
187+
platform::errors::InvalidArgument(
188+
"Size of PassDesc should not be empty."));
189+
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
190+
// Check inputs/outputs of subgraph should in `var_maps`.
191+
// Seed the allowed-name sets with the mapped variable names.
std::set<std::string> pattern_var_sets, replace_var_sets;
192+
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
193+
pattern_var_sets.emplace(var_map.pattern_var());
194+
replace_var_sets.emplace(var_map.replace_var());
195+
}
196+
// NOTE(review): `[=]` copy-capture is unnecessary here — the lambda only
// uses its own parameters; `[]` would do.
auto check_vars = [=](std::set<std::string>* var_sets,
197+
const proto::BlockDesc& block) {
198+
// First pass: every op output is a legal name for later inputs.
for (const proto::OpDesc& op : block.ops()) {
199+
for (const proto::OpDesc::Var& var : op.outputs()) {
200+
for (const std::string& argument : var.arguments()) {
201+
var_sets->emplace(argument);
202+
}
203+
}
204+
}
205+
// Second pass: every op input must be in the accumulated set.
for (const proto::OpDesc& op : block.ops()) {
206+
for (const proto::OpDesc::Var& var : op.inputs()) {
207+
for (const std::string& argument : var.arguments()) {
208+
PADDLE_ENFORCE_NE(
209+
var_sets->find(argument), var_sets->end(),
210+
platform::errors::InvalidArgument(
211+
"Subgraph of PassDesc has argument [%s] not in `var_maps`.",
212+
argument));
213+
}
214+
}
215+
}
216+
};
217+
check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0));
218+
check_vars(&replace_var_sets, pass_desc.replace().blocks(0));
219+
}
220+
}
221+
222+
// Stub: graph validation after rewrite is not implemented yet; the parameter
// is unused and every graph is accepted.
bool GeneratePass::VerifyGraph(const Graph& graph) {
223+
// Return true temporarily.
224+
return true;
225+
}
226+
227+
} // namespace ir
228+
} // namespace framework
229+
} // namespace paddle
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
17+
#include "paddle/fluid/framework/ir/pass.h"
18+
#include "paddle/fluid/framework/pass_desc.pb.h"
19+
20+
namespace paddle {
21+
namespace framework {
22+
namespace ir {
23+
24+
// Generate a substitute pass from protobuf.
25+
// A Pass whose pattern-match-and-rewrite behavior is described entirely by a
// proto::MultiPassDesc instead of hand-written C++: each contained PassDesc
// supplies a pattern program, a replace program, and the variable mapping
// between them. The desc is validated at construction time.
class GeneratePass : public Pass {
26+
public:
27+
// Construct from a serialized proto::MultiPassDesc binary string.
28+
explicit GeneratePass(const std::string& binary_str);
29+
// Construct from an in-memory proto::MultiPassDesc (copied).
30+
explicit GeneratePass(const proto::MultiPassDesc& multi_pass_desc);
31+
32+
protected:
33+
// Runs every PassDesc in multi_pass_desc_ against the graph.
void ApplyImpl(Graph* graph) const override;
34+
35+
private:
36+
// A desc-less pass is meaningless; copying is disabled per project macro.
GeneratePass() = delete;
37+
DISABLE_COPY_AND_ASSIGN(GeneratePass);
38+
// Verify that multi_pass_desc_ is non-empty and internally consistent.
39+
void VerifyDesc() const;
40+
// Verify the rewritten graph (currently a stub that always succeeds).
41+
static bool VerifyGraph(const Graph& graph);
42+
43+
// The declarative description driving this pass.
proto::MultiPassDesc multi_pass_desc_;
44+
};
45+
46+
} // namespace ir
47+
} // namespace framework
48+
} // namespace paddle

0 commit comments

Comments
 (0)