Skip to content

Commit 9495708

Browse files
authored
[Cherry-pick2.3] Optimize dygraph performance part3 (#42256)
* Change small vector size (#42202) * change samll vector size * Update type_defs.h * Optimize dygraph InferShape perf (#42155) * init commit * remove two hash impl * fix bug * polish details * fix compile failed * fix compile failed * fix compile failed * add default kernel sig cache * fix get kernel arg defs error * remove kernel arg defs cache * fix origin op execute
1 parent f16087e commit 9495708

File tree

13 files changed

+176
-101
lines changed

13 files changed

+176
-101
lines changed

paddle/fluid/framework/infershape_utils.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -402,11 +402,11 @@ std::vector<phi::MetaTensor*> CompatInferMetaContext::MutableOutputBetween(
402402
CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
403403
const std::string& op_type) {
404404
// 1. get kernel args
405-
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
405+
auto* arg_map_fn = ctx->GetPhiArgumentMappingFn();
406406
InferShapeArgumentMappingContext arg_map_context(*ctx);
407-
KernelSignature signature =
408-
arg_map_fn ? (*arg_map_fn)(arg_map_context)
409-
: phi::DefaultKernelSignatureMap::Instance().Get(op_type);
407+
phi::KernelSignature signature = arg_map_fn
408+
? (*arg_map_fn)(arg_map_context)
409+
: *ctx->GetPhiDefaultKernelSignature();
410410
VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;
411411

412412
// 2. build infermeta context

paddle/fluid/framework/new_executor/new_executor_defs.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ void InterpretercoreInferShapeContext::SetOutputsDim(
393393
SetDims(vars, dims);
394394
}
395395

396+
const phi::ArgumentMappingFn*
397+
InterpretercoreInferShapeContext::GetPhiArgumentMappingFn() const {
398+
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
399+
}
400+
401+
const phi::KernelSignature*
402+
InterpretercoreInferShapeContext::GetPhiDefaultKernelSignature() const {
403+
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
404+
}
405+
396406
void InterpretercoreInferShapeContext::SetSkipLoD(bool skip) {
397407
can_skip_lod_ = skip;
398408
}

paddle/fluid/framework/new_executor/new_executor_defs.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
111111
void SetOutputsDim(const std::string& name,
112112
const std::vector<DDim>& dims) override;
113113

114+
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override;
115+
116+
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override;
117+
114118
void SetSkipLoD(bool skip);
115119

116120
protected:

paddle/fluid/framework/op_desc.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,14 @@ class CompileTimeInferShapeContext : public InferShapeContext {
271271
SetDims(names, dims);
272272
}
273273

274+
const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
275+
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
276+
}
277+
278+
const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
279+
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
280+
}
281+
274282
protected:
275283
std::vector<proto::VarType::Type> GetVarTypes(
276284
const std::vector<std::string> &names) const {

paddle/fluid/framework/operator.cc

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,14 @@ class RuntimeInferShapeContext : public InferShapeContext {
10071007
SetDims(vars, dims);
10081008
}
10091009

1010+
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
1011+
return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
1012+
}
1013+
1014+
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
1015+
return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
1016+
}
1017+
10101018
protected:
10111019
DDim GetDim(Variable* var) const {
10121020
PADDLE_ENFORCE_NOT_NULL(
@@ -1279,16 +1287,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
12791287
phi::KernelKey pt_kernel_key;
12801288
std::string pt_kernel_name;
12811289
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
1282-
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
1283-
pt_kernel_signature_.reset(
1284-
new KernelSignature(std::move(GetExpectedPhiKernelArgs(exe_ctx))));
1285-
VLOG(6) << *pt_kernel_signature_.get();
1290+
if (kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
1291+
kernel_signature_.reset(new phi::KernelSignature(
1292+
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
1293+
VLOG(6) << *kernel_signature_.get();
12861294

12871295
kernel_type_.reset(
12881296
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
12891297
dev_ctx = pool.Get(kernel_type_->place_);
12901298

1291-
pt_kernel_name = pt_kernel_signature_->name;
1299+
pt_kernel_name = kernel_signature_->name;
12921300
pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
12931301
pt_kernel_.reset(
12941302
new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
@@ -1303,7 +1311,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
13031311
<< "` not found.";
13041312
}
13051313
} else {
1306-
pt_kernel_name = pt_kernel_signature_->name;
1314+
pt_kernel_name = kernel_signature_->name;
13071315
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
13081316
// But the default library_type is Plain, so we need to modify the
13091317
// library_type here, otherwise it can't work.
@@ -1449,8 +1457,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
14491457
phi::KernelContext pt_kernel_context;
14501458
// Do data transform before building KernelContext
14511459
// TODO(zhiqiu): support TransferInplaceVarsBack
1452-
PreparePhiData(exec_scope, *pt_kernel_, *pt_kernel_signature_,
1453-
runtime_ctx);
1460+
PreparePhiData(exec_scope, *pt_kernel_, *kernel_signature_, runtime_ctx);
14541461
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &pt_kernel_context);
14551462
(*pt_kernel_)(&pt_kernel_context);
14561463
} else {
@@ -1545,14 +1552,14 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
15451552

15461553
phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
15471554
const ExecutionContext& ctx) const {
1548-
pt_kernel_signature_.reset(
1549-
new KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
1550-
VLOG(6) << *pt_kernel_signature_.get();
1555+
kernel_signature_.reset(
1556+
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
1557+
VLOG(6) << *kernel_signature_.get();
15511558

15521559
kernel_type_.reset(
15531560
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
15541561

1555-
auto pt_kernel_name = pt_kernel_signature_->name;
1562+
auto pt_kernel_name = kernel_signature_->name;
15561563
auto pt_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
15571564
pt_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
15581565
pt_kernel_name, pt_kernel_key)));
@@ -2153,16 +2160,16 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
21532160
tensor.layout());
21542161
}
21552162

2156-
KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
2163+
phi::KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
21572164
const ExecutionContext& ctx) const {
21582165
ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
21592166
if (arg_map_fn_ == nullptr) {
21602167
auto* arg_map_fn = phi::OpUtilsMap::Instance().GetArgumentMappingFn(type_);
21612168
if (arg_map_fn) {
21622169
arg_map_fn_.reset(new phi::ArgumentMappingFn(*arg_map_fn));
21632170
} else {
2164-
auto func =
2165-
[this](const phi::ArgumentMappingContext& ctx) -> KernelSignature {
2171+
auto func = [this](
2172+
const phi::ArgumentMappingContext& ctx) -> phi::KernelSignature {
21662173
return phi::DefaultKernelSignatureMap::Instance().Get(type_);
21672174
};
21682175
arg_map_fn_.reset(new phi::ArgumentMappingFn(func));
@@ -2173,7 +2180,8 @@ KernelSignature OperatorWithKernel::GetExpectedPhiKernelArgs(
21732180

21742181
Scope* OperatorWithKernel::PreparePhiData(
21752182
const Scope& scope, const phi::Kernel& pt_kernel,
2176-
const KernelSignature& pt_kernel_signature, RuntimeContext* ctx) const {
2183+
const phi::KernelSignature& pt_kernel_signature,
2184+
RuntimeContext* ctx) const {
21772185
const auto& input_names = pt_kernel_signature.input_names;
21782186
auto input_defs = pt_kernel.args_def().input_defs();
21792187
PADDLE_ENFORCE_EQ(input_names.size(), input_defs.size(),
@@ -2271,9 +2279,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
22712279
phi::KernelContext* pt_kernel_context) const {
22722280
pt_kernel_context->SetDeviceContext(dev_ctx);
22732281

2274-
auto& input_names = pt_kernel_signature_->input_names;
2275-
auto& attr_names = pt_kernel_signature_->attr_names;
2276-
auto& output_names = pt_kernel_signature_->output_names;
2282+
auto& input_names = kernel_signature_->input_names;
2283+
auto& attr_names = kernel_signature_->attr_names;
2284+
auto& output_names = kernel_signature_->output_names;
22772285

22782286
auto input_defs = pt_kernel_->args_def().input_defs();
22792287
auto attr_defs = pt_kernel_->args_def().attribute_defs();

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ class OperatorWithKernel : public OperatorBase {
632632
phi::KernelContext* pt_kernel_context) const;
633633

634634
phi::KernelSignature* PhiKernelSignature() const {
635-
return pt_kernel_signature_.get();
635+
return kernel_signature_.get();
636636
}
637637

638638
phi::Kernel* PhiKernel() const { return pt_kernel_.get(); }
@@ -704,7 +704,7 @@ class OperatorWithKernel : public OperatorBase {
704704
// we may polish the implementation here
705705
mutable bool run_phi_kernel_ = false;
706706
mutable bool run_kp_kernel = false;
707-
mutable std::unique_ptr<phi::KernelSignature> pt_kernel_signature_;
707+
mutable std::unique_ptr<phi::KernelSignature> kernel_signature_;
708708
mutable std::unique_ptr<phi::Kernel> pt_kernel_;
709709
mutable std::unique_ptr<phi::ArgumentMappingFn> arg_map_fn_;
710710
};

paddle/fluid/framework/phi_utils.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
4545
const paddle::SmallVector<const char*>& GetOutputArgsNames() override;
4646
const paddle::SmallVector<const char*>& GetAttrsArgsNames() override;
4747

48-
KernelSignature GetKernelSignature();
48+
phi::KernelSignature GetKernelSignature();
4949

5050
private:
5151
DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);
@@ -221,10 +221,10 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
221221
return attr_names_;
222222
}
223223

224-
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
225-
return KernelSignature(phi::TransToPhiKernelName(op_proto_->type()).c_str(),
226-
GetInputArgsNames(), GetAttrsArgsNames(),
227-
GetOutputArgsNames());
224+
phi::KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
225+
return phi::KernelSignature(
226+
phi::TransToPhiKernelName(op_proto_->type()).c_str(), GetInputArgsNames(),
227+
GetAttrsArgsNames(), GetOutputArgsNames());
228228
}
229229

230230
std::once_flag kernel_sig_map_init_flag;

paddle/fluid/framework/phi_utils.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@ limitations under the License. */
4040
namespace paddle {
4141
namespace framework {
4242

43-
using KernelSignature = phi::KernelSignature;
44-
4543
/* Kernel Key translate */
4644

4745
OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key);

paddle/fluid/framework/shape_inference.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ class InferShapeContext {
113113
virtual paddle::SmallVector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
114114
GetOutputVarPtrs(const std::string &name) const = 0;
115115

116+
virtual const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const = 0;
117+
118+
virtual const phi::KernelSignature *GetPhiDefaultKernelSignature() const = 0;
119+
116120
protected:
117121
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
118122
virtual void SetRepeatedDims(const std::string &name,

paddle/fluid/imperative/infer_shape_context.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
3737
const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
3838
const framework::AttributeMap* attr,
3939
const framework::AttributeMap* default_attr, const std::string op_type,
40-
const framework::OpKernelType* op_kernel_type = nullptr)
40+
const framework::OpKernelType* op_kernel_type = nullptr,
41+
const phi::ArgumentMappingFn* arg_map_fn = nullptr,
42+
const phi::KernelSignature* default_kernel_signature = nullptr)
4143
: var_map_in_(in),
4244
var_map_out_(out),
4345
attrs_(attr),
4446
default_attrs_(default_attr),
4547
op_type_(op_type),
46-
op_kernel_type_(op_kernel_type) {}
48+
op_kernel_type_(op_kernel_type),
49+
arg_map_fn_(arg_map_fn),
50+
default_kernel_signature_(default_kernel_signature) {}
4751

4852
bool HasInput(const std::string& name) const override {
4953
// has only one input
@@ -377,6 +381,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
377381
"SetLoDLevel function not support in dygraph mode"));
378382
}
379383

384+
const phi::ArgumentMappingFn* GetPhiArgumentMappingFn() const override {
385+
return arg_map_fn_;
386+
}
387+
388+
const phi::KernelSignature* GetPhiDefaultKernelSignature() const override {
389+
return default_kernel_signature_;
390+
}
391+
380392
protected:
381393
DDim GetDim(framework::Variable* var) const {
382394
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
@@ -438,6 +450,9 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
438450
const framework::AttributeMap* default_attrs_;
439451
const std::string op_type_;
440452
const framework::OpKernelType* op_kernel_type_;
453+
// arg_map_fn_ and default_kernel_signature_ may be nullptr
454+
const phi::ArgumentMappingFn* arg_map_fn_;
455+
const phi::KernelSignature* default_kernel_signature_;
441456
};
442457

443458
} // namespace imperative

0 commit comments

Comments
 (0)