
Commit 6a706e6

Convert GradMergeAllReduceOpHandle in GraphToBlock (#46544)
* Convert GradMergeAllReduceOpHandle in GraphToBlock
* Set FLAGS_CONVERT_GRAPH_TO_PROGRAM to False
1 parent 3fc4fa2 commit 6a706e6
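For context, the conversion this commit adds emits, for a non-fused grad-merge all-reduce node, a single c_allreduce_sum OpDesc whose Cond input is the gradient-merge condition variable, so the collective only runs once the merged gradient is ready. Below is a minimal illustrative sketch of that emitted op, not code from this commit; the variable names "x@GRAD" and "grad_merge_cond", the ring_id value, and the helper name are hypothetical placeholders.

// Illustrative sketch only; names and values below are placeholders.
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"

paddle::framework::OpDesc MakeGradMergeAllReduceDesc() {
  paddle::framework::OpDesc desc;
  desc.SetType("c_allreduce_sum");
  desc.SetInput("X", {"x@GRAD"});              // gradient produced on this rank
  desc.SetOutput("Out", {"x@GRAD"});           // reduced in place
  desc.SetInput("Cond", {"grad_merge_cond"});  // run only when grad merge completes
  desc.SetAttr("ring_id", 0);                  // placeholder communicator ring
  desc.SetAttr("use_calc_stream", true);
  desc.SetAttr(paddle::framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
               static_cast<int>(paddle::framework::OpRole::kBackward));
  return desc;
}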

1 file changed: +59 −58 lines changed


paddle/fluid/framework/ir/graph_helper.cc

+59 −58
@@ -506,73 +506,75 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }
 
-static void ReplaceAllReduceOp(const Node &node,
-                               proto::BlockDesc *block,
-                               std::vector<OpDesc> *ops) {
+void ReplaceAllReduceOp(const Node &node,
+                        proto::BlockDesc *block,
+                        std::vector<OpDesc> *ops) {
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ops->emplace_back();
-  auto &desc1 = ops->back();
-  std::string name = "fake_coalesce_" + std::to_string(ops->size());
-  desc1.SetType("check_memory_continue");
-
-  ops->emplace_back();
-  auto &desc2 = ops->back();
-  desc2.SetType("c_allreduce_sum");
-
-  if (node.IsWrappedBy<details::OpHandleBase>()) {
-    details::OpHandleBase &op_handler =
-        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+  bool is_fused = (node.Name() == "fused_all_reduce");
+  details::OpHandleBase &op_handle =
+      const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+
+  std::string all_reduce_var_name;
+  // If fused, add check_memory_continue OP to fuse inputs
+  if (is_fused) {
+    all_reduce_var_name = "fake_coalesce_" + std::to_string(ops->size());
+    proto::VarDesc var_desc;
+    var_desc.set_name(all_reduce_var_name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    VLOG(4) << "add variable for check_memory_continue: "
+            << all_reduce_var_name;
 
-    // set inputs
-    auto in_var_handles = op_handler.Inputs();
+    // get inputs of check_memory_continue
+    auto in_var_handles = op_handle.Inputs();
     std::vector<std::string> in_names;
     for (const auto &in : in_var_handles) {
       if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
         continue;
       }
       in_names.emplace_back(in->Name());
     }
-    desc1.SetInput("X", in_names);
-
-    proto::VarDesc var_desc;
-    var_desc.set_name(name);
-    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
-    block->mutable_vars()->Add()->CopyFrom(var_desc);
-    desc1.SetOutput("Out", {name});
-    desc1.SetOutput("XOut", in_names);
-    VLOG(4) << "add variable for check_memory_continue: " << name;
-
-    desc2.SetInput("X", {name});
-    // set outputs
-    auto out_var_handles = op_handler.Outputs();
-    std::vector<std::string> out_names;
-    for (const auto &out : out_var_handles) {
-      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
-        continue;
-      }
-      out_names.emplace_back(out->Name());
-    }
-    desc2.SetOutput("Out", {name});
-
-    int ring_id = platform::NCCLCommContext::Instance().GetRingId(
-        dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
-    desc2.SetAttr("ring_id", ring_id);
-    desc2.SetAttr("use_calc_stream", true);
 
-    // handle grad merge
-    if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
-      VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
-      auto cond_name =
-          dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
-              ->GradMergeCondName();
-      desc2.SetInput("Cond", {cond_name});
-    }
+    ops->emplace_back();
+    OpDesc &fuse_op_desc = ops->back();
+    fuse_op_desc.SetType("check_memory_continue");
+    fuse_op_desc.SetInput("X", in_names);
+    fuse_op_desc.SetOutput("Out", {all_reduce_var_name});
+    fuse_op_desc.SetOutput("XOut", in_names);
+    fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                         (static_cast<int>(OpRole::kBackward)));
+  } else {
+    all_reduce_var_name = op_handle.Inputs()[0]->Name();
   }
 
-  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
-  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-                (static_cast<int>(OpRole::kBackward)));
+  // add c_allreduce_sum OP
+  ops->emplace_back();
+  OpDesc &all_reduce_op_desc = ops->back();
+  all_reduce_op_desc.SetType("c_allreduce_sum");
+  all_reduce_op_desc.SetInput("X", {all_reduce_var_name});
+  all_reduce_op_desc.SetOutput("Out", {all_reduce_var_name});
+
+  int ring_id = platform::NCCLCommContext::Instance().GetRingId(
+      dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
+  all_reduce_op_desc.SetAttr("ring_id", ring_id);
+  all_reduce_op_desc.SetAttr("use_calc_stream", true);
+  all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                             (static_cast<int>(OpRole::kBackward)));
+
+  // handle grad merge
+  if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  } else if (dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)) {
+    VLOG(4) << "GradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
+    const std::string cond_name =
+        dynamic_cast<details::GradMergeAllReduceOpHandle *>(&op_handle)
+            ->GradMergeCondName();
+    all_reduce_op_desc.SetInput("Cond", {cond_name});
+  }
 #else
   PADDLE_THROW(
       platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
@@ -629,15 +631,14 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
   for (Node *n : nodes) {
     // if node is not Op, skip
     if (!n->IsOp()) continue;
-
     // create fill_constant op
     if (n->Name() == "scale_loss_grad") {
       VLOG(4) << "convert op node scale_loss_grad to desc fill_constant";
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
-    } else if (n->Name() == "fused_all_reduce") {
-      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+    } else if (n->Name() == "allreduce" || n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node " << n->Name() << " to desc c_allreduce_sum";
       ReplaceAllReduceOp(*n, block, ops);
       VLOG(4) << n->ToString();
     } else if (n->Op()) {
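As a usage note on the fused path above: for a fused_all_reduce node the conversion emits two ops, check_memory_continue to coalesce the per-parameter gradients into one buffer variable, then c_allreduce_sum to reduce that buffer with a single collective. The following is a minimal sketch of that pair of OpDescs, assuming hypothetical gradient names, a placeholder coalesced-buffer name and ring id, and a made-up helper function; it is not code from this commit.

// Minimal sketch; all names and the ring id below are placeholders.
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_desc.h"

std::vector<paddle::framework::OpDesc> MakeFusedAllReduceDescs(
    const std::vector<std::string> &grad_names) {
  std::vector<paddle::framework::OpDesc> ops(2);
  const std::string coalesced = "fake_coalesce_0";  // placeholder buffer name

  paddle::framework::OpDesc &fuse = ops[0];
  fuse.SetType("check_memory_continue");
  fuse.SetInput("X", grad_names);        // gradients to be fused
  fuse.SetOutput("Out", {coalesced});    // handle to the coalesced buffer
  fuse.SetOutput("XOut", grad_names);

  paddle::framework::OpDesc &allreduce = ops[1];
  allreduce.SetType("c_allreduce_sum");
  allreduce.SetInput("X", {coalesced});  // reduce the whole buffer at once
  allreduce.SetOutput("Out", {coalesced});
  allreduce.SetAttr("ring_id", 0);       // placeholder communicator ring
  allreduce.SetAttr("use_calc_stream", true);
  return ops;
}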
