@@ -51,45 +51,39 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
5151 VLOG (3 ) << " ProcessGraph" ;
5252 RpcCtxMap send_varname_to_ctx;
5353 RpcCtxMap recv_varname_to_ctx;
54- for (auto i = 0 ; i < graphs.size (); ++i) {
55- std::vector<ir::Node *> nodes_to_delete;
56- for (auto &node : graphs[i]->Nodes ()) {
57- VLOG (3 ) << " node name " << node->Name ();
58- if (node && node->IsOp ()) {
59- if (node->Name () == " send" ) {
60- auto send_var_name = node->Op ()->Input (" X" )[0 ];
61- auto send_varnames = boost::get<std::vector<std::string>>(
62- node->Op ()->GetNullableAttr (" send_varnames" ));
63- auto epmap = boost::get<std::vector<std::string>>(
64- node->Op ()->GetNullableAttr (" epmap" ));
65- auto height_section = boost::get<std::vector<int64_t >>(
66- node->Op ()->GetNullableAttr (" sections" ));
67- auto trainer_id =
68- boost::get<int >(node->Op ()->GetNullableAttr (" trainer_id" ));
69- send_varname_to_ctx[send_var_name] =
70- operators::distributed::RpcContext (send_var_name, send_varnames,
71- epmap, height_section,
72- trainer_id);
73- VLOG (3 ) << " find and init an send op: "
74- << send_varname_to_ctx[send_var_name];
75- } else if (node->Name () == " recv" ) {
76- auto recv_var_name = node->Op ()->Output (" Out" )[0 ];
77- auto recv_varnames = boost::get<std::vector<std::string>>(
78- node->Op ()->GetNullableAttr (" recv_varnames" ));
79- auto epmap = boost::get<std::vector<std::string>>(
80- node->Op ()->GetNullableAttr (" epmap" ));
81- auto trainer_id =
82- boost::get<int >(node->Op ()->GetNullableAttr (" trainer_id" ));
83- recv_varname_to_ctx[recv_var_name] =
84- operators::distributed::RpcContext (recv_var_name, recv_varnames,
85- epmap, {}, trainer_id);
86- nodes_to_delete.push_back (node);
87- VLOG (3 ) << " find and remove an recv op: "
88- << recv_varname_to_ctx[recv_var_name];
89- }
54+ for (auto &node : graphs[0 ]->Nodes ()) {
55+ VLOG (3 ) << " node name " << node->Name ();
56+ if (node && node->IsOp ()) {
57+ if (node->Name () == " send" ) {
58+ auto send_var_name = node->Op ()->Input (" X" )[0 ];
59+ auto send_varnames = boost::get<std::vector<std::string>>(
60+ node->Op ()->GetNullableAttr (" send_varnames" ));
61+ auto epmap = boost::get<std::vector<std::string>>(
62+ node->Op ()->GetNullableAttr (" epmap" ));
63+ auto height_section = boost::get<std::vector<int64_t >>(
64+ node->Op ()->GetNullableAttr (" sections" ));
65+ auto trainer_id =
66+ boost::get<int >(node->Op ()->GetNullableAttr (" trainer_id" ));
67+ send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext (
68+ send_var_name, send_varnames, epmap, height_section, trainer_id);
69+ VLOG (3 ) << " find and init an send op: "
70+ << send_varname_to_ctx[send_var_name];
71+ } else if (node->Name () == " recv" ) {
72+ auto recv_var_name = node->Op ()->Output (" Out" )[0 ];
73+ auto recv_varnames = boost::get<std::vector<std::string>>(
74+ node->Op ()->GetNullableAttr (" recv_varnames" ));
75+ auto epmap = boost::get<std::vector<std::string>>(
76+ node->Op ()->GetNullableAttr (" epmap" ));
77+ auto trainer_id =
78+ boost::get<int >(node->Op ()->GetNullableAttr (" trainer_id" ));
79+ recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext (
80+ recv_var_name, recv_varnames, epmap, {}, trainer_id);
81+ VLOG (3 ) << " find and remove an recv op: "
82+ << recv_varname_to_ctx[recv_var_name];
9083 }
9184 }
9285 }
86+
9387 // init communicator here
9488 if (send_varname_to_ctx.size () > 0 ) {
9589 VLOG (3 ) << " this is distribute mode, will use communicator" ;
0 commit comments