Skip to content

Commit 2f9ec80

Browse files
committed
refine datatransfer
1 parent 0271d70 commit 2f9ec80

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

paddle/fluid/framework/new_executor/data_transfer.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
319319
}
320320
}
321321

322+
bool transfered = false;
322323
DataTranferHelper data_transfer_helper(place, var_scope);
323324
for (auto& var_name_item : *ins_map_temp) {
324325
bool should_skip_input =
@@ -389,6 +390,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
389390
}
390391

391392
if (is_transferred) {
393+
transfered = true;
392394
// update RuntimeContext.inputs and original op_func_node inputs
393395
op_func_node->input_index[var_name_item.first][i] =
394396
var_scope->VarId(new_var_name);
@@ -426,11 +428,13 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
426428
}
427429
}
428430

429-
// NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent
430-
// with instruction. (hot fix, it is not good design here)
431-
op_func_node->operator_base_ =
432-
std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
433-
op_base->Type(), new_ins, new_outs, op_base->Attrs()));
431+
if (transfered) {
432+
// NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent
433+
// with instruction. (hot fix, it is not good design here)
434+
op_func_node->operator_base_ =
435+
std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
436+
op_base->Type(), new_ins, new_outs, op_base->Attrs()));
437+
}
434438
op_func_node->no_data_transform_index = std::move(no_data_transform_index);
435439
}
436440

paddle/fluid/framework/new_executor/interpretercore_util.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,7 @@ void build_op_func_list(const platform::Place& place,
432432
// see OperatorWithKernel::RunImpl in operator.cc for why
433433
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
434434
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
435-
InterpretercoreInferShapeContext infer_shape_ctx(*op_with_kernel,
436-
runtime_context);
435+
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
437436
// TODO(Aurelius84): In case of control flow ops, they are NOT
438437
// inheritted from OperatorWithKernel.
439438
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1032,7 +1032,7 @@ set_tests_properties(test_parallel_executor_seresnext_with_reduce_gpu PROPERTIES
10321032
set_tests_properties(test_dropout_op PROPERTIES TIMEOUT 120)
10331033
set_tests_properties(test_argsort_op PROPERTIES TIMEOUT 120)
10341034
set_tests_properties(test_gather_nd_op PROPERTIES TIMEOUT 120)
1035-
set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 180)
1035+
set_tests_properties(test_nn_grad PROPERTIES TIMEOUT 300)
10361036
set_tests_properties(test_elementwise_sub_op PROPERTIES TIMEOUT 120)
10371037
set_tests_properties(test_row_conv_op PROPERTIES TIMEOUT 120)
10381038
set_tests_properties(test_parallel_executor_seresnext_with_fuse_all_reduce_gpu PROPERTIES TIMEOUT 120)

0 commit comments

Comments
 (0)