@@ -58,6 +58,41 @@ void PSGPUWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
     for (auto& op : ops_) {
       op->SetIsRuntimeInferShape(true);
     }
+
+    // reusing memory
+    auto input_names = device_reader_->GetInputVarNames();
+    std::set<std::string> input_names_set(input_names.begin(), input_names.end());
+    for (auto& scope : thread_scope_vec_) {
+      std::vector<Variable*> need_reuse;
+      for (auto& var : block.AllVars()) {
+        std::string name = var->Name();
+        if (!var->Persistable()) {
+          if (input_names_set.find(name) != input_names_set.end()) {
+            continue;
+          }
+          auto* ptr = scope->FindLocalVar(name);
+          PADDLE_ENFORCE_NE(ptr, nullptr,
+                            phi::errors::NotFound("The var %s is not found.", name));
+          need_reuse.push_back(ptr);
+        }
+      }
+      need_reuse_var_vec_[scope] = std::move(need_reuse);
+    }
+    {
+      need_reuse_var_.clear();
+      for (auto& var : block.AllVars()) {
+        std::string name = var->Name();
+        if (!var->Persistable()) {
+          if (input_names_set.find(name) != input_names_set.end()) {
+            continue;
+          }
+          auto* ptr = thread_scope_->FindLocalVar(name);
+          PADDLE_ENFORCE_NE(ptr, nullptr,
+                            phi::errors::NotFound("The var %s is not found.", name));
+          need_reuse_var_.push_back(ptr);
+        }
+      }
+    }
   }
 }
 
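The hunk above builds two index-aligned lists: `need_reuse_var_vec_[scope]` for each pipeline task scope and `need_reuse_var_` for the main `thread_scope_`. Both walk `block.AllVars()` in the same order and keep only non-persistable variables that the reader does not feed, so position `i` names the same variable in every scope. Below is a minimal, self-contained sketch of that filter; `VarDesc` and the variable names are toy stand-ins for illustration, not Paddle's real classes.

// Self-contained sketch of the filter; VarDesc and the names below are
// toy stand-ins for illustration, not Paddle's real classes.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct VarDesc {
  std::string name;
  bool persistable;
};

int main() {
  // Vars a block might declare: a reader input, a temporary, a parameter.
  std::vector<VarDesc> block_vars = {
      {"click", false},       // fed by the reader: skipped
      {"fc_0.tmp_0", false},  // temporary: eligible for buffer reuse
      {"fc_0.w_0", true},     // persistable parameter: skipped
  };
  std::set<std::string> input_names_set = {"click"};

  // Same filter as the hunk above: keep non-persistable vars the reader
  // does not feed, in block order, so every scope's list lines up
  // index-by-index with need_reuse_var_.
  std::vector<std::string> need_reuse;
  for (const auto& var : block_vars) {
    if (var.persistable) continue;
    if (input_names_set.count(var.name) != 0) continue;
    need_reuse.push_back(var.name);
  }
  for (const auto& n : need_reuse) std::cout << n << "\n";  // fc_0.tmp_0
  return 0;
}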
@@ -400,6 +435,18 @@ void PSGPUWorker::TrainFiles() {
             std::chrono::microseconds(200));
       }
       thread_scope = cur_task.scope;
+      // tensor share buffer
+      std::vector<Variable*>& cur_scope_vars = need_reuse_var_vec_[thread_scope];
+      PADDLE_ENFORCE_EQ(cur_scope_vars.size(), need_reuse_var_.size(),
+                        platform::errors::Fatal(
+                            "reuse vars size must be same."));
+      for (size_t i = 0; i < need_reuse_var_.size(); i++) {
+        Variable* child = cur_scope_vars[i];
+        Variable* parent = need_reuse_var_[i];
+        if (child->IsType<LoDTensor>()) {
+          child->GetMutable<LoDTensor>()->ShareBufferWith(*(parent->GetMutable<LoDTensor>()));
+        }
+      }
     }
 
     if (cur_batch <= 0) {
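Note on the hunk above: when a task scope is popped for a new batch, each of its reusable tensors is pointed at the buffer owned by the matching tensor in the main thread scope (`child` borrows from `parent` via `ShareBufferWith`), so the extra pipeline scopes add no device memory for temporaries. The companion hunk at the end of the loop body shares back in the opposite direction.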
@@ -409,9 +456,11 @@ void PSGPUWorker::TrainFiles() {
     total_ins_num += cur_batch;
 
     if (shape_check_flag_.load()) {
-      VLOG(0) << "Begin OpRunAndShapeCheck... "
+      VLOG(0) << "Begin OpRunAndShapeCheck, "
+              << shape_check_count_.load();
+      if (scope_num_ == 1 || shape_check_count_.fetch_sub(1) <= 0) {
+        VLOG(0) << "End OpRunAndShapeCheck."
               << shape_check_count_.load();
-      if (shape_check_count_.fetch_sub(1) <= 0) {
         shape_check_flag_ = false;
       }
     }
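Note on the hunk above: with a single scope (`scope_num_ == 1`) the shape check now ends after the first batch instead of counting `shape_check_count_` down to zero, and the end of the check is logged.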
@@ -514,6 +563,17 @@ void PSGPUWorker::TrainFiles() {
     ++batch_cnt;
 
     if (scope_num_ != 1) {
+      std::vector<Variable*>& cur_scope_vars = need_reuse_var_vec_[thread_scope];
+      PADDLE_ENFORCE_EQ(cur_scope_vars.size(), need_reuse_var_.size(),
+                        platform::errors::Fatal(
+                            "reuse vars size must be same."));
+      for (size_t i = 0; i < need_reuse_var_.size(); i++) {
+        Variable* child = cur_scope_vars[i];
+        Variable* parent = need_reuse_var_[i];
+        if (child->IsType<LoDTensor>()) {
+          parent->GetMutable<LoDTensor>()->ShareBufferWith(*(child->GetMutable<LoDTensor>()));
+        }
+      }
       device_reader_->get_pack(cur_task.pack);
       free_task_queue_.Push(cur_task);
     }
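Taken together, the two `ShareBufferWith` hunks form a round trip: before a batch runs, the task scope's tensor borrows the main scope's buffer, and after the batch the main scope adopts whatever allocation the task scope finished with, so a buffer that grew during the batch is kept for the next one. The sketch below models that ownership hand-off with a toy `Tensor` whose `ShareBufferWith` mimics the no-copy semantics of `phi::DenseTensor`; it is an illustration, not the real API.

// Toy model of the parent/child buffer round trip; not Paddle code.
#include <iostream>
#include <memory>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> buf;
  // Mimics DenseTensor::ShareBufferWith: adopt the other tensor's
  // allocation instead of copying it.
  void ShareBufferWith(const Tensor& other) { buf = other.buf; }
};

int main() {
  Tensor parent;  // lives in thread_scope_ (need_reuse_var_)
  parent.buf = std::make_shared<std::vector<float>>(1024);

  Tensor child;   // lives in the task scope picked for this batch
  child.ShareBufferWith(parent);   // before the batch: child borrows
  // ... ops run; they may replace child.buf with a larger allocation ...
  parent.ShareBufferWith(child);   // after the batch: parent keeps it

  std::cout << (parent.buf == child.buf) << "\n";  // prints 1: one buffer
  return 0;
}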