@@ -107,6 +107,10 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
107107
108108std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build (
109109 const ProgramDesc &program) const {
110+ std::unordered_map<std::string, proto::VarType::Type> var_types;
111+ for (auto *var : program.Block (0 ).AllVars ()) {
112+ var_types[var->Name ()] = var->GetType ();
113+ }
110114 auto graph = new SSAGraph ();
111115 SSAGraph &result = *graph;
112116 std::unordered_set<std::string> og_has_been_broadcast;
@@ -116,7 +120,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
116120 std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
117121 places_.size ());
118122
119- size_t cur_update_sparse_gp_dev_id = 0 ;
123+ size_t cur_dev_id = 0 ;
120124 std::vector<std::unordered_set<std::string>> sparse_var_name_on_devices;
121125 std::vector<std::unordered_set<std::string>> bcast_sparse_var_name_set;
122126
@@ -156,14 +160,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
156160 // broadcast, and each gradient is only broadcast once.
157161 for (auto &og : op->OutputArgumentNames ()) {
158162 if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
159- if (IsSparseGradient (og)) {
160- CreateReduceOp (&result, cur_update_sparse_gp_dev_id, og);
161- sparse_var_name_on_devices[cur_update_sparse_gp_dev_id].emplace (
162- og);
163- bcast_sparse_var_name_set[cur_update_sparse_gp_dev_id].emplace (
163+ if (IsSparseGradient (var_types, og)) {
164+ CreateReduceOp (&result, cur_dev_id, og);
165+ sparse_var_name_on_devices[cur_dev_id].emplace (og);
166+ bcast_sparse_var_name_set[cur_dev_id].emplace (
164167 og.substr (0 , og.size () - strlen (kGradVarSuffix )));
165- cur_update_sparse_gp_dev_id =
166- (cur_update_sparse_gp_dev_id + 1 ) % places_.size ();
168+ cur_dev_id = (cur_dev_id + 1 ) % places_.size ();
167169 } else {
168170 InsertNCCLAllReduceOp (&result, og);
169171 }
@@ -201,10 +203,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
201203 return std::unique_ptr<SSAGraph>(graph);
202204}
203205
204- bool MultiDevSSAGraphBuilder::IsSparseGradient (const std::string &og) const {
205- auto og_var = local_scopes_[0 ]->FindVar (og);
206- PADDLE_ENFORCE_NOT_NULL (og_var);
207- return og_var->IsType <SelectedRows>();
206+ bool MultiDevSSAGraphBuilder::IsSparseGradient (
207+ const std::unordered_map<std::string, proto::VarType::Type> &var_types,
208+ const std::string &og) const {
209+ PADDLE_ENFORCE (var_types.count (og) != 0 );
210+ if (var_types.at (og) == proto::VarType::SELECTED_ROWS) {
211+ return true ;
212+ }
213+ return false ;
208214}
209215
210216int MultiDevSSAGraphBuilder::GetOpDeviceID (
0 commit comments