Skip to content

Commit aa6ed57

Browse files
add use_pt_kernel Flags to control whether to use pt kernel (#13)
* add use_pt_kernel Flags to control whether to use pt kernel * change the default value to true for cheking pt kernels
1 parent 9b33270 commit aa6ed57

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ DECLARE_bool(check_nan_inf);
5151
DECLARE_bool(enable_unused_var_check);
5252
PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
5353
"number of threads for inner op");
54+
DECLARE_bool(use_pt_kernel);
5455

5556
namespace paddle {
5657
namespace framework {
@@ -1155,7 +1156,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
11551156
// phase
11561157

11571158
// VLOG(1) << "Pt KernelFactory: " << pt::KernelFactory::Instance();
1158-
if (pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
1159+
if (FLAGS_use_pt_kernel &&
1160+
pt::KernelFactory::Instance().ContainsKernel(type_.c_str())) {
11591161
if (pt_kernel_key_.get() == nullptr || pt_kernel_.get() == nullptr) {
11601162
ChoosePtKernel(*runtime_ctx, *dev_ctx);
11611163
}

paddle/fluid/imperative/prepared_operator.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "paddle/fluid/platform/xpu/xpu_op_list.h"
2323
#endif
2424
DECLARE_bool(check_nan_inf);
25+
DECLARE_bool(use_pt_kernel);
2526

2627
namespace paddle {
2728
namespace imperative {
@@ -205,7 +206,8 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
205206
#endif
206207

207208
// 1. get expected kernel key
208-
if (pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
209+
if (FLAGS_use_pt_kernel &&
210+
pt::KernelFactory::Instance().ContainsKernel(op.Type().c_str())) {
209211
auto kernel_name =
210212
ConstructPtKernelName<VarType>(op.Type(), (*op.Info().proto_), ins);
211213
auto inputs = BuildInputMap<VarType>(ins);

paddle/fluid/platform/flags.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,3 +673,17 @@ PADDLE_DEFINE_EXPORTED_int32(get_host_by_name_time, 120,
673673
PADDLE_DEFINE_EXPORTED_bool(
674674
apply_pass_to_program, false,
675675
"It controls whether to apply IR pass to program when using Fleet APIs");
676+
677+
/**
678+
* Pt kernel related FLAG
679+
* Name: FLAGS_use_pt_kernel
680+
* Since Version: 2.2.0
681+
* Value Range: bool, default=false
682+
* Example: FLAGS_use_pt_kernel=true would use the pt kernel to compute in the
683+
* Op.
684+
* Note:
685+
*/
686+
// TODO(chentianyu03): change default value to false before merge into develop
687+
// branch
688+
PADDLE_DEFINE_EXPORTED_bool(use_pt_kernel, true,
689+
"It controls whether to use pt kernel");

0 commit comments

Comments
 (0)