
Commit 86b0a72

Authored by chengduo
Refine multi thread cpu parallel exe (#11406)
* refine multi-thread CPU Parallel exe
* refine multi thread CPU Parallel exe
* Refine CPU version for ParallelExecutor
* add share_parameter_between_cards_
* Fix ParallelExecutor bug
* Fix unit test
* Fix parameter opt balance
* Fix with opti (param->grad)
* Add grad to op var
* Remove shard_param_between_cards
1 parent 76086df commit 86b0a72
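For orientation before the per-file diffs: the sketch below shows how a training script would opt into the CPU reduce mode this commit refines. It is a minimal sketch, not part of the commit; the BuildStrategy/ExecutionStrategy/ParallelExecutor option names come from the diffs and the test harness below, while the toy network and the exact constructor arguments are assumptions that may vary slightly across fluid versions.

import paddle.fluid as fluid

# Toy network (assumed); any fluid program with a loss works here.
img = fluid.layers.data(name='img', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
prediction = fluid.layers.fc(input=img, size=10, act='softmax')
loss = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))
fluid.optimizer.Adam().minimize(loss)

place = fluid.CPUPlace()
fluid.Executor(place).run(fluid.default_startup_program())

# Option names taken from the test harness changed in this commit.
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
exec_strategy = fluid.ExecutionStrategy()

# With ReduceStrategy.Reduce, parallel_executor.cc below now enforces that
# more than one place is used.
exe = fluid.ParallelExecutor(use_cuda=False,
                             loss_name=loss.name,
                             exec_strategy=exec_strategy,
                             build_strategy=build_strategy)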

9 files changed: +121 additions, -83 deletions


paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 26 additions & 12 deletions
@@ -276,13 +276,22 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     }
   }
 
-  // Insert BCast Ops
-  for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
-    auto &to_bcast_set = bcast_var_name_set[dev_id];
-    for (auto &bcast_name : to_bcast_set) {
-      CreateBroadcastOp(&result, bcast_name, dev_id);
+  bool use_gpu = false;
+#ifdef PADDLE_WITH_CUDA
+  use_gpu = nccl_ctxs_ != nullptr;
+#endif
+
+  if (use_gpu ||
+      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+    // Insert BCast Ops
+    for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
+      auto &to_bcast_set = bcast_var_name_set[dev_id];
+      for (auto &bcast_name : to_bcast_set) {
+        CreateBroadcastOp(&result, bcast_name, dev_id);
+      }
     }
   }
+
   /*
   Dependency graph has been constructed. However, there are still data
   hazards need to be handled.
@@ -412,14 +421,19 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const OpDesc &op) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
-
-  for (auto &varname : op.InputArgumentNames()) {
-    int dev_id = GetVarDeviceID(varname);
-    if (dev_id != -1) {
-      return dev_id;
-    }
+  int op_role = boost::get<int>(
+      op.GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
+  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
+    return -1;
   }
-  return -1;
+  auto param_grad = boost::get<std::vector<std::string>>(
+      op.GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
+  int dev_id = GetVarDeviceID(param_grad[1]);
+  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]", op.Type(),
+                    param_grad[0]);
+  return dev_id;
 }
 
 int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
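In reduce mode, GetOpDeviceID now pins each optimize op to the device that already owns its gradient (the second entry of the op's OpRoleVar attribute) instead of scanning the op's input variables. The Python-style pseudocode below is only an illustrative rendering of that C++ rule; the op.attr accessor and var_device_id lookup are hypothetical stand-ins for OpDesc::GetAttr and GetVarDeviceID.

def get_op_device_id(op, var_device_id, reduce_mode, OpRole):
    # Hypothetical helpers: op.attr() mirrors OpDesc::GetAttr in the C++ above.
    if not reduce_mode:
        return -1                          # only kReduce pins ops to a device
    if op.attr('op_role') != int(OpRole.Optimize):
        return -1                          # only optimize ops carry [param, grad]
    param, grad = op.attr('op_role_var')   # enforced to be exactly two names
    dev_id = var_device_id(grad)           # the reduce pass already placed grad
    assert dev_id != -1, 'grad %s of op %s has no device' % (grad, param)
    return dev_id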

paddle/fluid/framework/parallel_executor.cc

Lines changed: 18 additions & 5 deletions
@@ -45,6 +45,7 @@ class ParallelExecutorPrivate {
 #endif
   bool own_local_scope_;
   bool use_cuda_;
+  bool use_all_reduce_;
 };
 
 std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
@@ -62,6 +63,14 @@ ParallelExecutor::ParallelExecutor(
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
   member_->use_cuda_ = exec_strategy.use_cuda_;
+  member_->use_all_reduce_ =
+      build_strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce;
+
+  if (!member_->use_all_reduce_) {
+    PADDLE_ENFORCE(places.size() > 1,
+                   "If you set build_strategy.reduce with 'Reduce',"
+                   "the number of places must be greater than 1.");
+  }
 
   // Step 1. Bcast the params to devs.
   // Create local scopes
@@ -117,7 +126,7 @@ ParallelExecutor::ParallelExecutor(
 #ifdef PADDLE_WITH_CUDA
     builder_factory.SetNCCLContextMap(member_->nccl_ctxs_.get());
 #else
-    PADDLE_THROW("Not compiled with CUDA");
+    PADDLE_THROW("Not compiled with CUDA.");
 #endif
   }
 
@@ -133,7 +142,7 @@ ParallelExecutor::ParallelExecutor(
 
 void ParallelExecutor::BCastParamsToDevs(
     const std::unordered_set<std::string> &vars) const {
-  // the the initializing bcast, all vars would be bcast from device(0),
+  // the initializing bcast, all vars would be bcast from device(0),
   // otherwise
   // bcast from the specified device.
   bool initializing = builder_.get() == nullptr ? true : false;
@@ -209,9 +218,13 @@ void ParallelExecutor::BCastParamsToDevs(
 
       auto local_scope = member_->local_scopes_[i];
       auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
-      t->Resize(dims);
-      t->mutable_data(cpu, main_tensor.type());
-      paddle::framework::TensorCopy(main_tensor, cpu, t);
+      if (member_->use_all_reduce_ || member_->use_cuda_) {
+        t->Resize(dims);
+        t->mutable_data(cpu, main_tensor.type());
+        paddle::framework::TensorCopy(main_tensor, cpu, t);
+      } else {
+        t->ShareDataWith(main_tensor);
+      }
     }
   }
 }
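The broadcast of initial parameters now branches on the execution mode: GPU or all-reduce runs keep a per-device copy, while the CPU reduce mode lets every local scope alias one buffer, since the CPU "devices" are worker threads in a single address space. A hedged sketch of just that decision, with placeholder tensor methods (resize/copy_from/share_data_with) standing in for Resize/TensorCopy/ShareDataWith in the C++ above:

def place_initial_param(dst, src, use_all_reduce, use_cuda):
    if use_all_reduce or use_cuda:
        # Each device materialises its own copy of the parameter.
        dst.resize(src.shape)
        dst.copy_from(src)
    else:
        # CPU reduce mode: share the same underlying buffer across scopes.
        dst.share_data_with(src)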

python/paddle/fluid/clip.py

Lines changed: 9 additions & 5 deletions
@@ -324,10 +324,12 @@ def set_gradient_clip(clip, param_list=None, program=None):
         param.gradient_clip_attr = copy.deepcopy(clip)
 
 
-def append_gradient_clip_ops(param_grad):
+def append_gradient_clip_ops(param_grads):
     context = dict()
-    for p, g in param_grad:
-        with p.block.program.optimized_guard(p):
+    for p, g in param_grads:
+        if g is None:
+            continue
+        with p.block.program.optimized_guard([p, g]):
             clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
             if clip_attr is None:
                 clip_attr = NullGradientClipAttr()
@@ -339,8 +341,10 @@ def append_gradient_clip_ops(param_grad):
             clip_attr._process_context(context=context, param=p, grad=g)
 
     res = []
-    for p, g in param_grad:
-        with p.block.program.optimized_guard(p):
+    for p, g in param_grads:
+        if g is None:
+            continue
+        with p.block.program.optimized_guard([p, g]):
             res.append(clip_attr._create_operators(param=p, grad=g))
 
     return res

python/paddle/fluid/framework.py

Lines changed: 7 additions & 4 deletions
@@ -1319,25 +1319,28 @@ def set_op_role_var(self, var_name):
         self._op_role_var = [var_name]
 
     @contextlib.contextmanager
-    def optimized_guard(self, var):
+    def optimized_guard(self, param_and_grads):
         """
         A with guard to set :code:`Optimization` :code:`OpRole` and
         :code:`OpRoleVar` automatically.
 
         Notes: This is a very low level API. Users should not use it directly.
 
         Args:
-            var(Variable|str): The variable (name) to be optimized.
+            param_and_grads(list): The variables (names) to be optimized.
 
         Examples:
 
            >>> p, g = backward(...)
-           >>> with program.optimized_guard(p):
+           >>> with program.optimized_guard([p,g]):
            >>>     p = p - 0.001 * g
         """
         OpRole = core.op_proto_and_checker_maker.OpRole
         self._current_role = OpRole.Optimize
-        self._op_role_var = [var.name if isinstance(var, Variable) else var]
+        self._op_role_var = [
+            var.name if isinstance(var, Variable) else var
+            for var in param_and_grads
+        ]
         yield
         self._op_role_var = []
         self._current_role = OpRole.Forward

python/paddle/fluid/optimizer.py

Lines changed: 19 additions & 13 deletions
@@ -123,7 +123,7 @@ def _create_accumulators(self, block, parameters):
         """
         pass
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, parameters_and_grads):
         """Finish any custom updates needed
            before completing an optimization step
 
@@ -226,18 +226,18 @@ def _create_optimization_pass(self,
 
         optimize_ops = []
         for param_and_grad in parameters_and_grads:
+            if param_and_grad[1] is None:
+                continue
             with param_and_grad[0].block.program.optimized_guard(
-                    param_and_grad[0]):
-                if param_and_grad[0].trainable is True and param_and_grad[
-                        1] is not None:
+                    param_and_grad):
+                if param_and_grad[0].trainable is True:
                     optimize_op = self._append_optimize_op(loss.block,
                                                            param_and_grad)
                     optimize_ops.append(optimize_op)
 
         # Get custom finish ops for subclasses
         # FIXME: Need to fix this once we figure out how to handle dependencies
-        self._finish_update(loss.block,
-                            [p[0] for p in parameters_and_grads])
+        self._finish_update(loss.block, parameters_and_grads)
 
         end = len(global_block.ops)
         return global_block.slice_ops(start, end)
@@ -564,13 +564,15 @@ def _append_optimize_op(self, block, param_and_grad):
 
         return adam_op
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, param_and_grads):
         """Update Beta1 and Beta2 Power accumulators
         """
         assert isinstance(block, framework.Block)
         main_block = block.program.global_block()
-        for param in parameters:
-            with param.block.program.optimized_guard(param):
+        for param, grad in param_and_grads:
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
@@ -691,13 +693,15 @@ def _append_optimize_op(self, block, param_and_grad):
 
         return adamax_op
 
-    def _finish_update(self, block, parameters):
+    def _finish_update(self, block, parameters_and_grads):
         """Update Beta1 Power accumulator
         """
         assert isinstance(block, framework.Block)
         main_block = block.program.global_block()
-        for param in parameters:
-            with param.block.program.optimized_guard(param):
+        for param, grad in parameters_and_grads:
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
                                                       param)
                 main_block.append_op(
@@ -1158,7 +1162,9 @@ def __init__(self,
             self.params_grads.append((param, grad))
 
         for param, grad in self.params_grads:
-            with param.block.program.optimized_guard(param):
+            if grad is None:
+                continue
+            with param.block.program.optimized_guard([param, grad]):
                 self._append_average_accumulate_op(param)
 
         self.apply_program = Program()
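Any optimizer subclass that overrides _finish_update must follow the new contract: it receives (param, grad) pairs rather than bare parameters, skips entries whose grad is None, and enters optimized_guard with both names so the multi-device builder can place the resulting op. A hedged sketch of a hypothetical subclass (the class name, accumulator name, and scale op body are illustrative only, not part of this commit):

import paddle.fluid as fluid

class MyOptimizer(fluid.optimizer.Optimizer):           # hypothetical subclass
    def _finish_update(self, block, parameters_and_grads):
        main_block = block.program.global_block()
        for param, grad in parameters_and_grads:
            if grad is None:                             # skip grad-less params
                continue
            # Record both param and grad so GetOpDeviceID can place this op.
            with param.block.program.optimized_guard([param, grad]):
                acc = self._get_accumulator('my_acc', param)  # assumed accumulator
                main_block.append_op(
                    type='scale',
                    inputs={'X': acc},
                    outputs={'Out': acc},
                    attrs={'scale': 0.9})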

python/paddle/fluid/regularizer.py

Lines changed: 5 additions & 6 deletions
@@ -41,12 +41,11 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
     """
     params_and_grads = []
     for param, grad in parameters_and_grads:
-        with param.block.program.optimized_guard(param):
-            # If no gradient then we don't need to do anything
-            if grad is None:
-                params_and_grads.append((param, grad))
-                continue
-
+        # If no gradient then we don't need to do anything
+        if grad is None:
+            params_and_grads.append((param, grad))
+            continue
+        with param.block.program.optimized_guard([param, grad]):
             regularization_term = None
             if param.regularizer is not None:
                 # Add variable for regularization term in grad block

python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

Lines changed: 8 additions & 2 deletions
@@ -35,7 +35,7 @@ def check_network_convergence(self,
                                   feed_dict=None,
                                   seed=None,
                                   use_parallel_executor=True,
-                                  balance_parameter_opt_between_cards=False):
+                                  use_reduce=False):
        def run_executor(exe, feed, fetch_list, program=None):
            if isinstance(exe, fluid.ParallelExecutor):
                res = exe.run(fetch_list=fetch_list, feed=feed)
@@ -50,22 +50,28 @@ def run_executor(exe, feed, fetch_list, program=None):
        main = fluid.Program()
        startup = fluid.Program()
        startup.random_seed = 1  # Fix random seed
+       main.random_seed = 1
        with fluid.program_guard(main, startup):
            if seed is not None:
                startup.random_seed = seed
+               main.random_seed = seed
+
            loss = method(use_feed=feed_dict is not None)
            adam = fluid.optimizer.Adam()
            adam.minimize(loss)
+
            if memory_opt:
                fluid.memory_optimize(main)
+
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        startup_exe = fluid.Executor(place)
        startup_exe.run(startup)
        exec_strategy = fluid.ExecutionStrategy()
        exec_strategy.allow_op_delay = allow_op_delay
 
        build_strategy = fluid.BuildStrategy()
-       build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce if balance_parameter_opt_between_cards else fluid.BuildStrategy.ReduceStrategy.AllReduce
+       build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
+           if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
 
        if use_parallel_executor:
            exe = fluid.ParallelExecutor(