@@ -37,20 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
3737 const std::string &loss_var_name,
3838 const std::unordered_set<std::string> ¶ms,
3939 const std::vector<Scope *> &local_scopes,
40- platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale)
40+ platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
41+ bool balance_parameter_opt_between_cards)
4142 : loss_var_name_(loss_var_name),
4243 places_(places),
4344 local_scopes_(local_scopes),
44- nccl_ctxs_(nccl_ctxs) {
45+ nccl_ctxs_(nccl_ctxs),
46+ balance_parameter_opt_between_cards_(
47+ balance_parameter_opt_between_cards) {
4548#else
4649MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder (
4750 const std::vector<platform::Place> &places,
4851 const std::string &loss_var_name,
4952 const std::unordered_set<std::string> ¶ms,
50- const std::vector<Scope *> &local_scopes, bool use_default_grad_scale)
53+ const std::vector<Scope *> &local_scopes, bool use_default_grad_scale,
54+ bool balance_parameter_opt_between_cards)
5155 : loss_var_name_ (loss_var_name),
5256 places_ (places),
53- local_scopes_ (local_scopes) {
57+ local_scopes_ (local_scopes),
58+ balance_parameter_opt_between_cards_ (
59+ balance_parameter_opt_between_cards) {
5460#endif
5561 for (auto &p : params) {
5662 grad_names_.insert (GradVarName (p));
@@ -124,6 +130,12 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
124130 // Find "send" op first for split is in front of send.
125131 OpDesc *send_op = GetSendOpDesc (program);
126132
133+ size_t cur_device_id = 0 ;
134+ std::vector<std::unordered_set<std::string>> var_name_on_devices;
135+ std::vector<std::unordered_set<std::string>> bcast_var_name_set;
136+ var_name_on_devices.resize (places_.size ());
137+ bcast_var_name_set.resize (places_.size ());
138+
127139 bool is_forwarding = true ;
128140 for (auto *op : program.Block (0 ).AllOps ()) {
129141 if (op->Type () == " send" ) {
@@ -139,24 +151,47 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
139151 }
140152 is_forwarding = false ;
141153 } else {
142- CreateComputationalOps (&result, *op, places_.size ());
154+ int op_dev_id = GetOpDeviceID (var_name_on_devices, *op);
155+ if (op_dev_id == -1 ) { // var on all device
156+ CreateComputationalOps (&result, *op, places_.size ());
157+ } else {
158+ CreateComputationalOp (&result, *op, op_dev_id);
159+ for (auto &var_name : op->OutputArgumentNames ()) {
160+ var_name_on_devices[op_dev_id].emplace (var_name);
161+ }
162+ }
143163 if (!is_forwarding && places_.size () > 1 ) {
144164 // Currently, we assume that once gradient is generated, it can be
145165 // broadcast, and each gradient is only broadcast once.
146166 for (auto &og : op->OutputArgumentNames ()) {
147167 if (IsParameterGradientOnce (og, &og_has_been_broadcast)) {
148- if (IsSparseGradient (var_types, og)) {
149- CreateReduceOp (&result, og, 0 );
150- CreateBroadcastOp (&result, og, 0 );
168+ if (balance_parameter_opt_between_cards_) {
169+ CreateReduceOp (&result, og, cur_device_id);
170+ var_name_on_devices[cur_device_id].emplace (og);
171+ bcast_var_name_set[cur_device_id].emplace (
172+ og.substr (0 , og.size () - strlen (kGradVarSuffix )));
173+ cur_device_id = (cur_device_id + 1 ) % places_.size ();
151174 } else {
152- InsertNCCLAllReduceOp (&result, og);
175+ if (IsSparseGradient (var_types, og)) {
176+ CreateReduceOp (&result, og, 0 );
177+ CreateBroadcastOp (&result, og, 0 );
178+ } else {
179+ InsertNCCLAllReduceOp (&result, og);
180+ }
153181 }
154182 }
155183 }
156184 }
157185 }
158186 }
159187
188+ // Insert BCast Ops
189+ for (size_t dev_id = 0 ; dev_id < bcast_var_name_set.size (); ++dev_id) {
190+ auto &to_bcast_set = bcast_var_name_set[dev_id];
191+ for (auto &bcast_name : to_bcast_set) {
192+ CreateBroadcastOp (&result, bcast_name, dev_id);
193+ }
194+ }
160195 /*
161196 Dependency graph has been constructed. However, there are still data
162197 harzaeds need to be handled.
@@ -265,6 +300,26 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
265300 return is_pg_once;
266301}
267302
303+ int MultiDevSSAGraphBuilder::GetOpDeviceID (
304+ const std::vector<std::unordered_set<std::string>> &var_name_on_devices,
305+ const OpDesc &op) const {
306+ if (!balance_parameter_opt_between_cards_) {
307+ return -1 ;
308+ }
309+
310+ int var_dev_id = -1 ;
311+ for (auto &var_name : op.InputArgumentNames ()) {
312+ if (var_dev_id != -1 ) break ;
313+ for (size_t i = 0 ; i < var_name_on_devices.size (); ++i) {
314+ if (var_name_on_devices[i].count (var_name)) {
315+ var_dev_id = static_cast <int >(i);
316+ break ;
317+ }
318+ }
319+ }
320+ return var_dev_id;
321+ }
322+
268323void MultiDevSSAGraphBuilder::CreateScaleLossGradOp (SSAGraph *result) const {
269324 for (size_t i = 0 ; i < places_.size (); ++i) {
270325// Insert ScaleCost OpHandle
0 commit comments