@@ -108,7 +108,9 @@ void BasicEngine::CheckBackwardInputs(const OpBase& op) {
   }
 }
 
-void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
+void BasicEngine::PrepareGradAccumulators(
+    const OpBase& op,
+    const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes) {
   for (const auto& pair : op.GetOutsMap()) {
     if (!pair.second.IsGrad()) {
       continue;
@@ -117,29 +119,94 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
     for (const auto& var : pair.second) {
       if (!var) continue;
 
-      auto& accumulator = accumulators_[var.get()];
-      if (!accumulator) {
-        if (FLAGS_sort_sum_gradient) {
-          accumulator.reset(new SortedGradientAccumulator(var.get()));
-        } else {
-          accumulator.reset(new EagerGradientAccumulator(var.get()));
+      if (!var->HasGradNode()) {
+        auto& accumulator = accumulators_[var.get()];
+        if (!accumulator) {
+          if (FLAGS_sort_sum_gradient) {
+            accumulator.reset(new SortedGradientAccumulator(var.get()));
+          } else {
+            accumulator.reset(new EagerGradientAccumulator(var.get()));
+          }
         }
-      }
 
-      accumulator->IncreaseRefCnt();
+        accumulator->IncreaseRefCnt();
 
-      VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
-              << var.get() << ") with reference count "
-              << accumulator->RefCnt();
+        VLOG(3) << "Prepare to acccumulate variable grad " << var->Name() << "("
+                << var.get()
+                << ") that don't have grad node with reference count "
+                << accumulator->RefCnt();
+
+        if (var->HasLeafHooks()) {
+          VLOG(3) << "Grad variable wrapper (" << var->Name()
+                  << ") has leaf grad hooks.";
+          PADDLE_ENFORCE_NE(
+              var->HasGradNode(), true,
+              platform::errors::PermissionDenied(
+                  "Only leaf Tensor's gradient can append hook to "
+                  "Gradientaccumulator."));
+          accumulator->SetPostHooks(var->GetLeafHooks());
+        }
+      } else {
+        // Because an Inplace op overwrites the grad_node of the input
+        // grad_var, only the information in grad_pending_node can be used
+        // to find the grad_node of the grad_var.
+        bool find_grad_node_of_var = false;
+        for (auto& grad_pending_node : grad_pending_nodes) {
+          PADDLE_ENFORCE_NOT_NULL(
+              grad_pending_node,
+              platform::errors::NotFound("Grad pending node is nullptr."));
+          for (auto& grad_pending_op : *grad_pending_node) {
+            VLOG(6) << "Determine whether var (" << var->Name()
+                    << ") is the input var of grad_pending_op ("
+                    << grad_pending_op.Type() << ").";
+            grad_pending_op.EnforceHasInOut();
+            for (const auto& grad_pending_op_ins_pair :
+                 grad_pending_op.GetInsMap()) {
+              if (!grad_pending_op_ins_pair.second.IsGrad()) {
+                continue;
+              }
+              for (const auto& pending_in_var :
+                   grad_pending_op_ins_pair.second) {
+                if (var == pending_in_var) {
+                  VLOG(6) << "Var (" << var->Name()
+                          << ") is the input var of grad_pending_op ("
+                          << grad_pending_op.Type() << ").";
+                  find_grad_node_of_var = true;
+                  break;
+                }
+              }
+              if (find_grad_node_of_var) {
+                break;
+              }
+            }
+          }
 
-      if (var->HasLeafHooks()) {
-        VLOG(3) << "Grad variable wrapper (" << var->Name()
-                << ") has leaf grad hooks.";
-        PADDLE_ENFORCE_NE(var->HasGradNode(), true,
-                          platform::errors::PermissionDenied(
-                              "Only leaf Tensor's gradient can append hook to "
-                              "Gradientaccumulator."));
-        accumulator->SetPostHooks(var->GetLeafHooks());
+          if (find_grad_node_of_var) {
+            auto& accumulator =
+                accumulators_with_grad_node_[grad_pending_node][var.get()];
+
+            if (!accumulator) {
+              if (FLAGS_sort_sum_gradient) {
+                accumulator.reset(new SortedGradientAccumulator(var.get()));
+              } else {
+                accumulator.reset(new EagerGradientAccumulator(var.get()));
+              }
+            }
+
+            accumulator->IncreaseRefCnt();
+
+            VLOG(3) << "Prepare to acccumulate variable grad " << var->Name()
+                    << "(" << var.get()
+                    << ") that has grad node with reference count "
+                    << accumulator->RefCnt();
+            break;
+          }
+        }
+        PADDLE_ENFORCE_EQ(
+            find_grad_node_of_var, true,
+            platform::errors::NotFound(
+                "No grad node corresponding to grad Tensor (%s) was found.",
+                var->Name()));
       }
     }
   }
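Note: the branch added above keeps accumulators in two separate tables. A minimal sketch of the member declarations this implies follows; the names are taken from the diff, but the exact declarations in basic_engine.h are an assumption inferred from the usages here.

    // Assumed BasicEngine members (sketch, not part of this patch).
    // Accumulators for grad vars that have no grad node (typically leaf grads).
    std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
        accumulators_;
    // Accumulators for grad vars that do have a grad node, keyed first by the
    // pending grad node that consumes the var and then by the var itself, so a
    // var whose grad_node was overwritten by an inplace op can still be found.
    std::unordered_map<std::shared_ptr<GradOpNode>,
                       std::unordered_map<VariableWrapper*,
                                          std::unique_ptr<GradientAccumulator>>>
        accumulators_with_grad_node_;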
@@ -148,10 +215,13 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
 void BasicEngine::PrepareDeps() {
   PADDLE_ENFORCE_EQ(
       node_deps_.empty(), true,
-      platform::errors::AlreadyExists("Op deps must be initialized here"));
+      platform::errors::AlreadyExists("Op deps must be initialized."));
   PADDLE_ENFORCE_EQ(
       accumulators_.empty(), true,
-      platform::errors::AlreadyExists("Accumulators must be initialized here"));
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
+  PADDLE_ENFORCE_EQ(
+      accumulators_with_grad_node_.empty(), true,
+      platform::errors::AlreadyExists("Accumulators must be initialized."));
 
   std::queue<GradOpNode*> q;
   std::unordered_set<GradOpNode*> visited;
@@ -163,16 +233,17 @@ void BasicEngine::PrepareDeps() {
     auto* cur_node = q.front();
     q.pop();
 
+    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
+
     for (auto& cur_op : *cur_node) {
       cur_op.EnforceHasInOut();
-      PrepareGradAccumulators(cur_op);
+      PrepareGradAccumulators(cur_op, grad_pending_nodes);
     }
 
-    const auto& grad_pending_nodes = cur_node->GradPendingNodes();
     for (auto& grad_pending_node : grad_pending_nodes) {
       PADDLE_ENFORCE_NOT_NULL(
           grad_pending_node,
-          platform::errors::NotFound("Grad pending node should not be null"));
+          platform::errors::NotFound("Grad pending node is nullptr."));
       ++node_deps_[grad_pending_node.get()];
       if (visited.count(grad_pending_node.get()) == 0) {
         visited.insert(grad_pending_node.get());
@@ -198,6 +269,8 @@ void BasicEngine::Execute() {
     auto shared_cur_node = std::move(q.front());
     q.pop();
 
+    auto& inplace_grad_name_map = shared_cur_node->InplaceGradNameMap();
+
     for (auto& cur_op : *shared_cur_node) {
       ++op_num;
 
@@ -222,11 +295,38 @@ void BasicEngine::Execute() {
             continue;
           }
 
-          auto iter = accumulators_.find(var.get());
-          PADDLE_ENFORCE_EQ(
-              iter != accumulators_.end(), true,
-              platform::errors::NotFound("Cannot find gradient of variable %s",
-                                         var->Name()));
+          std::unordered_map<VariableWrapper*,
+                             std::unique_ptr<GradientAccumulator>>::iterator
+              iter;
+          if (!var->HasGradNode()) {
+            VLOG(10) << "Find gradient of var (" << var->Name()
+                     << ") with no grad_node.";
+            iter = accumulators_.find(var.get());
+            PADDLE_ENFORCE_EQ(
+                iter != accumulators_.end(), true,
+                platform::errors::NotFound(
+                    "Cannot find gradient of variable %s", var->Name()));
+          } else {
+            bool flag_find_grad = false;
+            VLOG(10) << "Find gradient of var (" << var->Name()
+                     << ") with grad_node.";
+            for (auto& grad_pending_node :
+                 shared_cur_node->GradPendingNodes()) {
+              const auto& iter_grad_node =
+                  accumulators_with_grad_node_.find(grad_pending_node);
+              if (iter_grad_node != accumulators_with_grad_node_.end()) {
+                iter = iter_grad_node->second.find(var.get());
+                if (iter != iter_grad_node->second.end()) {
+                  flag_find_grad = true;
+                  break;
+                }
+              }
+            }
+            PADDLE_ENFORCE_EQ(
+                flag_find_grad, true,
+                platform::errors::NotFound(
+                    "Cannot find gradient of variable %s", var->Name()));
+          }
 
           // leaf_accumulators_ : hooks and accumulate-grad for leaf tensor
           if (var->IsLeafGrad()) {
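For readability, the two-way lookup added above can be pictured as a standalone helper. The following is a hypothetical refactoring sketch, not part of the patch (FindAccumulator is an invented name), assuming the member maps described earlier:

    // Hypothetical helper (not in the patch) equivalent to the lookup above.
    // Returns nullptr when no accumulator is registered for `var`.
    GradientAccumulator* BasicEngine::FindAccumulator(
        const std::shared_ptr<GradOpNode>& cur_node, VariableWrapper* var) {
      if (!var->HasGradNode()) {
        // Vars without a grad node live in the plain accumulators_ map.
        auto it = accumulators_.find(var);
        return it != accumulators_.end() ? it->second.get() : nullptr;
      }
      // Vars with a grad node are keyed by the pending node that consumes them.
      for (auto& pending_node : cur_node->GradPendingNodes()) {
        auto node_it = accumulators_with_grad_node_.find(pending_node);
        if (node_it == accumulators_with_grad_node_.end()) continue;
        auto it = node_it->second.find(var);
        if (it != node_it->second.end()) return it->second.get();
      }
      return nullptr;
    }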
@@ -245,6 +345,25 @@ void BasicEngine::Execute() {
             need_accu_var_list_.emplace_back(iter->second.get(), var);
             VLOG(10) << "create temporary var of " << var->Name()
                      << " for sum gradient within this graph!";
+          } else if (!inplace_grad_name_map.empty() &&
+                     inplace_grad_name_map.count(pair.first)) {
+            // When calculating an Inplace grad op, create a new output var.
+            // If a tmp var has already been created, there is no need to
+            // create it again.
+            for (auto& in_var :
+                 bwd_ins.at(inplace_grad_name_map.at(pair.first))) {
+              if (in_var == var) {
+                auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+                tmp_var->SetType(var->Type());
+                tmp_var->SetForwardDataType(var->ForwardDataType());
+                inplace_output_grad_var_list_.emplace_back(var, tmp_var);
+                var = tmp_var;
+                VLOG(10) << "Inplace grad op does not use the Inplace "
+                            "strategy, a temporary output var ("
+                         << var->Name() << ") will be created.";
+                break;
+              }
+            }
           }
         }
       }
@@ -280,6 +399,10 @@ void BasicEngine::Execute() {
                     cur_op.place());
       }
 
+      for (auto& pair : inplace_output_grad_var_list_) {
+        *pair.first = std::move(*pair.second);
+      }
+
       // Step 2: Sum Gradient of This graph
       for (auto& pair : need_accu_var_list_) {
         pair.first->SumGrad(std::move(pair.second), cur_op.id());
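Taken together with the temporary-var creation in the previous hunk, the loop added above completes the inplace fallback: the grad op writes into a temporary wrapper, and its contents are moved back into the original grad var once the op has run. A condensed sketch of that flow, where RunGradOp is a hypothetical stand-in for the actual OpBase::Run call:

    // Condensed sketch of the inplace fallback (names follow the diff;
    // RunGradOp is hypothetical).
    auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
    tmp_var->SetType(var->Type());
    tmp_var->SetForwardDataType(var->ForwardDataType());
    inplace_output_grad_var_list_.emplace_back(var, tmp_var);
    var = tmp_var;  // the inplace grad op now writes into the temporary

    RunGradOp(bwd_ins, tmp_outs);  // hypothetical: executes the grad op

    // Publish the result back so accumulators and hooks see the new gradient.
    for (auto& pair : inplace_output_grad_var_list_) {
      *pair.first = std::move(*pair.second);
    }
    inplace_output_grad_var_list_.clear();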
@@ -302,6 +425,7 @@ void BasicEngine::Execute() {
       }
 
       need_accu_var_list_.clear();
+      inplace_output_grad_var_list_.clear();
       leaf_accumulators_.clear();
 
       if (!retain_graph_) {
@@ -312,9 +436,9 @@ void BasicEngine::Execute() {
 
     // Step 3: Collect ready ops
     for (auto& grad_pending_node : shared_cur_node->GradPendingNodes()) {
-      PADDLE_ENFORCE_NOT_NULL(grad_pending_node,
-                              platform::errors::NotFound(
-                                  "Grad pending node should not be nullptr"));
+      PADDLE_ENFORCE_NOT_NULL(
+          grad_pending_node,
+          platform::errors::NotFound("Grad pending node is nullptr."));
       auto iter = node_deps_.find(grad_pending_node.get());
       if (iter == node_deps_.end()) {
         continue;
@@ -334,6 +458,7 @@ void BasicEngine::Clear() {
   init_node_.reset();
   node_deps_.clear();
   accumulators_.clear();
+  accumulators_with_grad_node_.clear();
   need_accu_var_list_.clear();
   leaf_accumulators_.clear();
 }