Skip to content

Commit ea3ef49

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into eager_dygraph_codegen_joint
2 parents 65e0036 + 472c908 commit ea3ef49

File tree

121 files changed

+9250
-1007
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

121 files changed

+9250
-1007
lines changed

paddle/fluid/distributed/fleet_executor/fleet_executor.cc

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,22 @@ FleetExecutor::FleetExecutor(const std::string& exe_desc_str) {
3333

3434
FleetExecutor::~FleetExecutor() { root_scope_->DropKids(); }
3535

36-
void FleetExecutor::Init(const framework::ProgramDesc& program_desc,
37-
framework::Scope* scope,
38-
const platform::Place& place) {
39-
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
36+
void FleetExecutor::Init(
37+
const framework::ProgramDesc& program_desc, framework::Scope* scope,
38+
const platform::Place& place, const std::vector<TaskNode*>& task_nodes,
39+
const std::unordered_map<int64_t, int64_t>& task_id_to_rank) {
40+
if (task_nodes.size() == 0) {
41+
runtime_graph_ = std::make_shared<RuntimeGraph>(program_desc, exe_desc_);
42+
} else {
43+
runtime_graph_ = std::make_shared<RuntimeGraph>();
44+
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
45+
for (auto task_node : task_nodes) {
46+
int64_t interceptor_id = task_node->task_id();
47+
interceptor_id_to_task.emplace(interceptor_id, task_node);
48+
}
49+
runtime_graph_->SetInterceptorIdToRank(task_id_to_rank);
50+
runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task);
51+
}
4052
root_scope_ = scope;
4153
place_ = place;
4254
PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(

paddle/fluid/distributed/fleet_executor/fleet_executor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@ namespace distributed {
3030
class RuntimeGraph;
3131
class Carrier;
3232
class MessageBus;
33+
class TaskNode;
3334

3435
class FleetExecutor final {
3536
public:
3637
FleetExecutor() = delete;
3738
explicit FleetExecutor(const std::string& exe_desc_str);
3839
~FleetExecutor();
3940
void Init(const framework::ProgramDesc& program_desc, framework::Scope* scope,
40-
const platform::Place& place);
41+
const platform::Place& place,
42+
const std::vector<TaskNode*>& task_nodes,
43+
const std::unordered_map<int64_t, int64_t>& task_id_to_rank);
4144
void Run();
4245

4346
private:

paddle/fluid/distributed/fleet_executor/runtime_graph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,14 @@ class RuntimeGraph final {
4444
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank() const {
4545
return intercepter_id_to_rank_;
4646
}
47+
void SetInterceptorIdToRank(
48+
const std::unordered_map<int64_t, int64_t>& intercepter_id_to_rank) {
49+
intercepter_id_to_rank_ = intercepter_id_to_rank;
50+
}
51+
void SetInterceptorIdToNode(
52+
const std::unordered_map<int64_t, TaskNode*>& intercepter_id_to_node) {
53+
intercepter_id_to_node_ = intercepter_id_to_node;
54+
}
4755
std::string DebugString() const;
4856

4957
private:

paddle/fluid/eager/legacy/infer_shape_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,13 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
216216

217217
// TODO(paddle-dev): Can this be template?
218218
std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
219-
const std::string& name) override {
219+
const std::string& name) const override {
220220
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
221221
"GetInputVarPtrs not support in dygraph runtime context"));
222222
}
223223

224224
std::vector<paddle::framework::InferShapeVarPtr> GetOutputVarPtrs(
225-
const std::string& name) override {
225+
const std::string& name) const override {
226226
PADDLE_THROW(paddle::platform::errors::PermissionDenied(
227227
"GetOutputVarPtrs not support in dygraph runtime context"));
228228
}

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,15 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
6969
desc->SetOutput("Output",
7070
std::vector<std::string>({activation_out->Name()}));
7171

72-
desc->SetAttr("fuse_activation", activation_type());
72+
if (activation_type() == "gelu" &&
73+
activation->Op()->HasAttr("approximate")) {
74+
bool approximate =
75+
BOOST_GET_CONST(bool, activation->Op()->GetAttr("approximate"));
76+
std::string type = approximate ? "_tanh" : "_erf";
77+
desc->SetAttr("fuse_activation", "gelu" + type);
78+
} else {
79+
desc->SetAttr("fuse_activation", activation_type());
80+
}
7381

7482
// MKLDNN ops use alpha and beta as activation parameters but paddle ops are
7583
// not generalized
@@ -240,6 +248,19 @@ Conv2DHardSigmoidFusePass::Conv2DHardSigmoidFusePass() {
240248
.End();
241249
}
242250

251+
Conv2DGeluFusePass::Conv2DGeluFusePass() {
252+
AddOpCompat(OpCompat("gelu"))
253+
.AddInput("X")
254+
.IsTensor()
255+
.End()
256+
.AddOutput("Out")
257+
.IsTensor()
258+
.End()
259+
.AddAttr("approximate")
260+
.IsType<bool>()
261+
.End();
262+
}
263+
243264
} // namespace ir
244265
} // namespace framework
245266
} // namespace paddle
@@ -294,3 +315,11 @@ REGISTER_PASS_CAPABILITY(conv_hard_sigmoid_mkldnn_fuse_pass)
294315
paddle::framework::compatible::OpVersionComparatorCombination()
295316
.LE("conv2d", 1)
296317
.EQ("hard_sigmoid", 0));
318+
319+
REGISTER_PASS(conv_gelu_mkldnn_fuse_pass,
320+
paddle::framework::ir::Conv2DGeluFusePass);
321+
REGISTER_PASS_CAPABILITY(conv_gelu_mkldnn_fuse_pass)
322+
.AddCombination(
323+
paddle::framework::compatible::OpVersionComparatorCombination()
324+
.LE("conv2d", 1)
325+
.EQ("gelu", 0));

paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ class Conv2DHardSigmoidFusePass : public ConvActivationFusePass {
8181
std::string activation_type() const { return "hard_sigmoid"; }
8282
};
8383

84+
/*
85+
* Fuse Conv and Gelu class
86+
*/
87+
class Conv2DGeluFusePass : public ConvActivationFusePass {
88+
public:
89+
Conv2DGeluFusePass();
90+
std::string activation_type() const { return "gelu"; }
91+
};
92+
8493
} // namespace ir
8594
} // namespace framework
8695
} // namespace paddle

paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ DepthwiseConvMKLDNNPass::DepthwiseConvMKLDNNPass() {
6868
.IsType<std::vector<int>>()
6969
.End()
7070
.AddAttr("data_format")
71-
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
71+
.IsStringIn({"NCHW", "AnyLayout"})
7272
.End();
7373
}
7474

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }
309309

310310
// TODO(paddle-dev): Can this be template?
311311
std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
312-
const std::string& name) {
312+
const std::string& name) const {
313313
const std::vector<Variable*>& vars = InputVars(name);
314314
std::vector<InferShapeVarPtr> res;
315315
res.reserve(vars.size());
@@ -318,7 +318,8 @@ std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
318318
}
319319

320320
std::vector<InferShapeVarPtr>
321-
InterpretercoreInferShapeContext::GetOutputVarPtrs(const std::string& name) {
321+
InterpretercoreInferShapeContext::GetOutputVarPtrs(
322+
const std::string& name) const {
322323
const std::vector<Variable*>& vars = OutputVars(name);
323324
std::vector<InferShapeVarPtr> res;
324325
res.reserve(vars.size());

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
8686

8787
// TODO(paddle-dev): Can this be template?
8888
std::vector<InferShapeVarPtr> GetInputVarPtrs(
89-
const std::string& name) override;
89+
const std::string& name) const override;
9090

9191
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
92-
const std::string& name) override;
92+
const std::string& name) const override;
9393

9494
DDim GetInputDim(const std::string& name) const override;
9595

paddle/fluid/framework/op_desc.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
200200
}
201201

202202
std::vector<InferShapeVarPtr> GetInputVarPtrs(
203-
const std::string &name) override {
203+
const std::string &name) const override {
204204
const std::vector<std::string> arg_names = Inputs(name);
205205
std::vector<InferShapeVarPtr> res;
206206
res.reserve(arg_names.size());
@@ -212,7 +212,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
212212
}
213213

214214
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
215-
const std::string &name) override {
215+
const std::string &name) const override {
216216
const std::vector<std::string> arg_names = Outputs(name);
217217
std::vector<InferShapeVarPtr> res;
218218
res.reserve(arg_names.size());

0 commit comments

Comments
 (0)