@@ -506,73 +506,75 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }

-static void ReplaceAllReduceOp(const Node &node,
-                               proto::BlockDesc *block,
-                               std::vector<OpDesc> *ops) {
+void ReplaceAllReduceOp(const Node &node,
+                        proto::BlockDesc *block,
+                        std::vector<OpDesc> *ops) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ops->emplace_back();
-  auto &desc1 = ops->back();
-  std::string name = "fake_coalesce_" + std::to_string(ops->size());
-  desc1.SetType("check_memory_continue");
-
-  ops->emplace_back();
-  auto &desc2 = ops->back();
-  desc2.SetType("c_allreduce_sum");
-
-  if (node.IsWrappedBy<details::OpHandleBase>()) {
-    details::OpHandleBase &op_handler =
-        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+  bool is_fused = (node.Name() == "fused_all_reduce");
+  details::OpHandleBase &op_handle =
+      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+
+  std::string all_reduce_var_name;
+  // If fused, add check_memory_continue OP to fuse inputs
+  if (is_fused) {
+    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
+    proto::VarDesc var_desc;
+    var_desc.set_name(all_reduce_var_name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    VLOG(4) << "add variable for check_memory_continue: "
+            << all_reduce_var_name;

-    // set inputs
-    auto in_var_handles = op_handler.Inputs();
+    // get inputs of check_memory_continue
+    auto in_var_handles = op_handle.Inputs();
     std::vector<std::string> in_names;
     for (const auto &in : in_var_handles) {
       if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
         continue;
       }
       in_names.emplace_back(in->Name());
     }
-    desc1.SetInput("X", in_names);
-
-    proto::VarDesc var_desc;
-    var_desc.set_name(name);
-    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
-    block->mutable_vars()->Add()->CopyFrom(var_desc);
-    desc1.SetOutput("Out", {name});
-    desc1.SetOutput("XOut", in_names);
-    VLOG(4) << "add variable for check_memory_continue: " << name;
-
-    desc2.SetInput("X", {name});
-    // set outputs
-    auto out_var_handles = op_handler.Outputs();
-    std::vector<std::string> out_names;
-    for (const auto &out : out_var_handles) {
-      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
-        continue;
-      }
-      out_names.emplace_back(out->Name());
-    }
-    desc2.SetOutput("Out", {name});
-
-    int ring_id = platform::NCCLCommContext::Instance().GetRingId(
-        dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
-    desc2.SetAttr("ring_id", ring_id);
-    desc2.SetAttr("use_calc_stream", true);

-    // handle grad merge
-    if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
-      VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
-      auto cond_name =
-          dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
-              ->GradMergeCondName();
-      desc2.SetInput("Cond", {cond_name});
-    }
+    ops->emplace_back();
+    OpDesc &fuse_op_desc = ops->back();
+    fuse_op_desc.SetType("check_memory_continue");
+    fuse_op_desc.SetInput("X", in_names);
+    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
+    fuse_op_desc.SetOutput("XOut", in_names);
+    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                         (static_cast<int>(OpRole::kBackward)));
+  } else {
+    all_reduce_var_name = op_handle.Inputs()[0]->Name();
   }

-  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
-  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
+  // add c_allreduce_sum OP
+  ops->emplace_back();
+  OpDesc &all_reduce_op_desc = ops->back();
+  all_reduce_op_desc.SetType("c_allreduce_sum");
+  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
+  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
+
+  int ring_id = platform::NCCLCommContext::Instance().GetRingId(
+      dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
+  all_reduce_op_desc.SetAttr("ring_id", ring_id);
+  all_reduce_op_desc.SetAttr("use_calc_stream", true);
+  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                             (static_cast<int>(OpRole::kBackward)));
+
+  // handle grad merge
+  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  } else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  }
 #else
   PADDLE_THROW(
       platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
   for (Node *n : nodes) {
     // if node is not Op, skip
     if (!n->IsOp()) continue;
-
     // create fill_constant op
     if (n->Name() == "scale_loss_grad") {
       VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "fused_all_reduce") {
-      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
     } else if (n->Op()) {