Skip to content

Commit 053ecd6

Browse files
committed
Refine CPU version for ParallelExecutor
1 parent efaed58 commit 053ecd6

File tree

6 files changed

+28
-36
lines changed

6 files changed

+28
-36
lines changed

paddle/fluid/framework/details/all_reduce_op_handle.cc

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,16 @@ AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
2727
const std::vector<platform::Place> &places,
2828
const platform::NCCLContextMap *ctxs)
2929
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
30-
use_cuda_ = false;
3130
if (nccl_ctxs_) {
3231
for (auto &p : places_) {
3332
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
3433
}
35-
use_cuda_ = true;
3634
}
3735
}
3836
#else
3937
AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
4038
const std::vector<platform::Place> &places)
41-
: local_scopes_(local_scopes), places_(places) {
42-
use_cuda_ = false;
43-
}
39+
: local_scopes_(local_scopes), places_(places) {}
4440
#endif
4541

4642
void AllReduceOpHandle::RunImpl() {
@@ -117,28 +113,18 @@ void AllReduceOpHandle::RunImpl() {
117113
// Reduce All Tensor to trg in CPU
118114
ReduceLoDTensor func(lod_tensors, &trg);
119115
VisitDataType(ToDataType(lod_tensors[0]->type()), func);
120-
bool use_cuda = use_cuda_;
116+
121117
for (size_t i = 1; i < local_scopes_.size(); ++i) {
122118
auto &scope =
123119
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
124120
auto &p = places_[i];
125121
auto *var = scope.FindVar(out_var_handles[i]->name_);
126122
auto *dev_ctx = dev_ctxes_[p];
127123

128-
RunAndRecordEvent(p, [&trg, var, dev_ctx, p, use_cuda] {
129-
#ifdef PADDLE_WITH_CUDA
130-
if (use_cuda) {
131-
auto &tensor_dst = *var->GetMutable<framework::LoDTensor>();
132-
auto &tensor_src = trg;
133-
TensorCopy(tensor_src, p, *dev_ctx, &tensor_dst);
134-
} else {
135-
auto &tensor_dst = *var->GetMutable<framework::LoDTensor>();
136-
tensor_dst.ShareDataWith(trg);
137-
}
138-
#else
139-
auto &tensor_dst = *var->GetMutable<framework::LoDTensor>();
140-
tensor_dst.ShareDataWith(trg);
141-
#endif
124+
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
125+
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
126+
auto &tensor_cpu = trg;
127+
TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
142128
});
143129
}
144130
}

paddle/fluid/framework/details/all_reduce_op_handle.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ struct AllReduceOpHandle : public OpHandleBase {
5252
#ifdef PADDLE_WITH_CUDA
5353
const platform::NCCLContextMap *nccl_ctxs_;
5454
#endif
55-
bool use_cuda_;
5655
};
5756

5857
} // namespace details

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,21 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
260260
}
261261
}
262262

263-
// Insert BCast Ops
264-
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
265-
auto &to_bcast_set = bcast_var_name_set[dev_id];
266-
for (auto &bcast_name : to_bcast_set) {
267-
CreateBroadcastOp(&result, bcast_name, dev_id);
263+
bool use_gpu = false;
264+
#ifdef PADDLE_WITH_CUDA
265+
use_gpu = nccl_ctxs_ != nullptr;
266+
#endif
267+
268+
if (use_gpu) {
269+
// Insert BCast Ops
270+
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
271+
auto &to_bcast_set = bcast_var_name_set[dev_id];
272+
for (auto &bcast_name : to_bcast_set) {
273+
CreateBroadcastOp(&result, bcast_name, dev_id);
274+
}
268275
}
269276
}
277+
270278
/*
271279
Dependency graph has been constructed. However, there are still data
272280
hazards need to be handled.

paddle/fluid/framework/parallel_executor.cc

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ ParallelExecutor::ParallelExecutor(
9595
}
9696

9797
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
98-
BCastParamsToGPUs(bcast_vars, member_->use_cuda_);
98+
BCastParamsToGPUs(bcast_vars);
9999
}
100100
// Startup Program has been run. All local scopes has correct parameters.
101101

@@ -132,9 +132,7 @@ ParallelExecutor::ParallelExecutor(
132132
}
133133

134134
void ParallelExecutor::BCastParamsToGPUs(
135-
const std::unordered_set<std::string> &vars, const bool use_cuda) const {
136-
auto *main_scope = member_->local_scopes_[0];
137-
135+
const std::unordered_set<std::string> &vars) const {
138136
// the the initialize bcast, all vars would be bcast from device(0), otherwise
139137
// bcast from the specified device.
140138
bool initialize = builder_.get() == nullptr ? true : false;
@@ -156,12 +154,11 @@ void ParallelExecutor::BCastParamsToGPUs(
156154
}
157155

158156
auto &main_tensor = main_var->Get<LoDTensor>();
159-
#ifdef PADDLE_WITH_CUDA
160-
auto &dims = main_tensor.dims();
161-
#endif
157+
162158
if (paddle::platform::is_gpu_place(main_tensor.place())) {
163159
#ifdef PADDLE_WITH_CUDA
164160
std::vector<void *> buffers;
161+
auto &dims = main_tensor.dims();
165162
size_t numel = main_tensor.numel();
166163
ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
167164
for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -200,7 +197,8 @@ void ParallelExecutor::BCastParamsToGPUs(
200197
auto local_scope = member_->local_scopes_[i];
201198
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
202199
#ifdef PADDLE_WITH_CUDA
203-
if (use_cuda) {
200+
if (member_->use_cuda_) {
201+
auto &dims = main_tensor.dims();
204202
t->Resize(dims);
205203
t->mutable_data(cpu, main_tensor.type());
206204
paddle::framework::TensorCopy(main_tensor, cpu, t);

paddle/fluid/framework/parallel_executor.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ class ParallelExecutor {
6666
void Run(const std::vector<std::string> &fetch_tensors,
6767
const std::string &fetched_var_name);
6868

69-
void BCastParamsToGPUs(const std::unordered_set<std::string> &vars,
70-
const bool use_cuda) const;
69+
void BCastParamsToGPUs(const std::unordered_set<std::string> &vars) const;
7170

7271
private:
7372
ParallelExecutorPrivate *member_;

python/paddle/fluid/tests/unittests/parallel_executor_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def run_executor(exe, feed, fetch_list, program=None):
4545
raise ValueError('Unkown type exe')
4646
return res
4747

48+
if not use_cuda:
49+
balance_parameter_opt_between_cards = True
4850
main = fluid.Program()
4951
startup = fluid.Program()
5052
startup.random_seed = 1 # Fix random seed

0 commit comments

Comments (0)