Commit 60ff237

v0

1 parent 0ad9212 commit 60ff237

File tree: 3 files changed, +132 −80 lines changed

paddle/fluid/framework/blocking_queue.h

Lines changed: 5 additions & 0 deletions

@@ -64,6 +64,11 @@ class BlockingQueue {
     return rc;
   }
 
+  size_t Size() {
+    std::unique_lock<std::mutex> lock(mutex_);
+    return q_.size();
+  }
+
  private:
   std::mutex mutex_;
   std::condition_variable cv_;
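The new accessor is tiny but worth a note: Size() must take the same mutex that guards q_, otherwise the read races with concurrent Push/Pop, and the count it returns is only a snapshot that may be stale by the time the caller inspects it. A minimal self-contained sketch of the pattern (illustrative names only, not Paddle's full BlockingQueue, which also offers a timed PopAll):

#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <queue>

template <typename T>
class TinyBlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      q_.push(std::move(v));
    }
    cv_.notify_one();  // wake one blocked Pop()
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return !q_.empty(); });  // block until non-empty
    T v = std::move(q_.front());
    q_.pop();
    return v;
  }

  std::size_t Size() {
    std::unique_lock<std::mutex> lock(mutex_);  // same lock as Push/Pop
    return q_.size();  // a snapshot; may change right after returning
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<T> q_;
};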

paddle/fluid/framework/details/threaded_ssa_graph_executor.cc

Lines changed: 116 additions & 78 deletions
@@ -28,100 +28,102 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       places_(places),
       fetch_ctxs_(places),
       running_ops_(0),
-      strategy_(strategy) {}
+      strategy_(strategy),
+      thread_cnt_(strategy.num_threads_) {}
 
-FeedFetchList ThreadedSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
-  std::unordered_map<OpHandleBase *, size_t> pending_ops;
-  std::unordered_set<VarHandleBase *> pending_vars;
-  BlockingQueue<VarHandleBase *> ready_vars;
-  std::unordered_set<OpHandleBase *> ready_ops;
-  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
-  // streams from multiple GPUs, it's faster to buffer them and schedule
-  // together since we currently cannot overlap computation and memcpy streams.
-  // Should revisit it if overlapping is available.
-  std::unordered_set<OpHandleBase *> delayed_ops;
-
-  // Transform SSAGraph to pending_ops & pending_vars
-  for (auto &var_map : graph_->vars_) {
-    for (auto &name_pair : var_map) {
-      for (auto &version_pair : name_pair.second) {
-        InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
+void ThreadedSSAGraphExecutor::RunOp(
+    std::atomic<int> *total_ops, BlockingQueue<OpHandleBase *> *ready_ops,
+    std::unordered_map<OpHandleBase *, std::atomic<int>> *pending_op_deps,
+    details::OpHandleBase *op) {
+  auto op_run = [ready_ops, pending_op_deps, op, total_ops, this] {
+    OpHandleBase *current_op = op;
+    while (true) {
+      // 1. If current_op is nullptr, get a runnable op from ready_ops
+      if (current_op == nullptr) {
+        if (*total_ops <= 0) break;
+        current_op = ready_ops->Pop();
       }
-    }
-  }
-  for (auto &var : graph_->dep_vars_) {
-    InsertPendingVar(&pending_vars, &ready_vars, var.get());
-  }
 
-  for (auto &op : graph_->ops_) {
-    if (op->Inputs().empty()) {  // Special case, Op has no input.
-      ready_ops.insert(op.get());
-    } else {
-      InsertPendingOp(&pending_ops, op.get());
+      // 2. Run the current op
+      try {
+        VLOG(10) << current_op << " " << current_op->Name() << " : "
+                 << current_op->DebugString();
+        current_op->Run(strategy_.use_event_);
+        VLOG(10) << current_op << " " << current_op->Name() << " Done ";
+      } catch (platform::EnforceNotMet ex) {
+        exception_.reset(new platform::EnforceNotMet(ex));
+      } catch (...) {
+        LOG(FATAL) << "Unknown exception caught";
+      }
+      total_ops->fetch_sub(1);
+      auto released_vars = current_op->Outputs();
+
+      // 3. Decrease the dependency counts in pending_op_deps according to
+      // released_vars, and find the next runnable op.
+      current_op = nullptr;
+      for (auto ready_var : released_vars) {
+        for (auto *op : ready_var->pending_ops_) {
+          auto dep_num = pending_op_deps->at(op).fetch_sub(1);
+          if (dep_num == 0) {
+            if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
+              ready_ops->Push(op);
+            } else {
+              if (!current_op) {
+                current_op = op;
+              }
+            }
+          }
+        }
+      }
     }
+  };
+
+  if (pool_) {
+    pool_->enqueue(op_run);
+  } else {
+    op_run();
   }
+}
 
-  // Step 2. Insert FetchOps
+FeedFetchList ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  // Step 1. Insert FetchOps
   std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
   std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());
 
-  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
-                 &pending_vars, &ready_vars, &fetch_data);
+  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &fetch_data);
 
-  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
-    for (auto *op : set) {
-      running_ops_++;
-      RunOp(&ready_vars, op);
-    }
-    set.clear();
-  };
+  // Step 2. Collect ready_ops and pending_op_deps
+  BlockingQueue<OpHandleBase *> ready_ops;  // read and write
+  std::unordered_map<OpHandleBase *, std::atomic<int>>
+      pending_op_deps;  // only read
 
-  // Step 3. Execution
-  while (!pending_vars.empty()) {
-    // 1. Run All Ready ops
-    // Keep loop until all vars are ready.
-    //
-    // NOTE: DelayedOps have a lower priority. It will be scheduled after all
-    // ready_ops have been performed.
-    if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) {
-      run_all_ops(delayed_ops);
+  for (auto &op : graph_->ops_) {
+    if (op->Inputs().empty()) {
+      ready_ops.Push(op.get());
     } else {
-      run_all_ops(ready_ops);
+      pending_op_deps.insert({op.get(), op->NoDupInputSize()});
     }
+  }
+  for (auto &op : fetch_ops) {
+    pending_op_deps.insert({op.get(), op->NoDupInputSize()});
+  }
 
-    // 2. Find ready variable
-    bool timeout;
-    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
-
-    if (timeout) {
-      if (exception_) {
-        auto exp = *exception_;
-        exception_.reset();
-        throw exp;
-      } else {
-        continue;
-      }
-    }
-    // 3. Remove the dependency of ready_var.
-    // Find the ready_ops after the ready_var.
-    for (auto ready_var : cur_ready_vars) {
-      pending_vars.erase(ready_var);
-      for (auto *op : ready_var->pending_ops_) {
-        auto &deps = pending_ops[op];
-        --deps;
-        if (deps == 0) {
-          if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
-            delayed_ops.insert(op);
-          } else {
-            ready_ops.insert(op);
-          }
-        }
-      }
-    }
+  // total_ops tells us when the worker loops are over
+  std::atomic<int> total_ops(
+      static_cast<int>(graph_->ops_.size() + fetch_ops.size()));
+
+  // Step 3. Execution
+  for (size_t i = 0; i < thread_cnt_; ++i) {
+    RunOp(&total_ops, &ready_ops, &pending_op_deps, nullptr);
   }
-  PADDLE_ENFORCE(ready_ops.empty());
+
+  // while (true) {
+  //   if (total_ops == 0) break;
+  // }
+
+  PADDLE_ENFORCE(total_ops == 0);
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
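The heart of the new RunOp is a dependency countdown: each pending op starts at NoDupInputSize(), every finished op decrements the counts of the ops that consume its outputs, and the thread whose decrement exhausts a count takes that op as runnable. One subtlety when writing this pattern: std::atomic::fetch_sub returns the value held before the subtraction, so the decrement that empties the count observes 1, not 0. A standalone sketch of the countdown under that semantics (hypothetical names, not Paddle code):

#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

int main() {
  std::atomic<int> deps(3);          // an op with 3 distinct inputs
  std::atomic<int> became_ready(0);  // how many threads claimed readiness

  std::vector<std::thread> finishers;
  for (int i = 0; i < 3; ++i) {
    finishers.emplace_back([&] {
      // fetch_sub returns the PREVIOUS value, so the thread that takes the
      // count from 1 to 0 sees 1; testing == 1 makes exactly one thread
      // mark the op ready.
      if (deps.fetch_sub(1) == 1) {
        became_ready.fetch_add(1);
      }
    });
  }
  for (auto &t : finishers) t.join();

  assert(became_ready.load() == 1);
  return 0;
}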
@@ -131,6 +133,42 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   return fetch_data;
 }
 
+void ThreadedSSAGraphExecutor::InsertFetchOps(
+    const std::vector<std::string> &fetch_tensors,
+    std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
+    std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
+    FeedFetchList *fetch_data) {
+  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
+
+  for (auto &fetch_var_name : fetch_tensors) {
+    for (auto &var_map : graph_->vars_) {
+      auto it = var_map.find(fetch_var_name);
+      if (it != var_map.end()) {
+        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
+      }
+    }
+  }
+
+  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
+    auto &var_name = fetch_tensors[i];
+    auto &vars = fetched_vars.at(var_name);
+    auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
+    fetch_ops->emplace_back(op);
+
+    for (auto &p : places_) {
+      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
+    }
+
+    for (auto *var : vars) {
+      op->AddInput(var);
+    }
+
+    auto *fetch_dummy = new DummyVarHandle();
+    op->AddOutput(fetch_dummy);
+    fetch_dependencies->emplace(fetch_dummy);
+  }
+}
+
 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
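The new InsertFetchOps resolves each fetch target against every device's var_map, and it->second.rbegin()->get() picks the last element of that variable's version list, i.e. the most recently written SSA version, which is the value a fetch should observe. A reduced sketch of that lookup with placeholder types:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Var {
  int version;  // SSA version; higher means a more recent write
};

int main() {
  // name -> versions in write order, as in one device's var_map.
  std::unordered_map<std::string, std::vector<std::unique_ptr<Var>>> var_map;
  var_map["loss"].emplace_back(new Var{0});
  var_map["loss"].emplace_back(new Var{1});

  auto it = var_map.find("loss");
  if (it != var_map.end()) {
    Var *latest = it->second.rbegin()->get();  // newest version
    std::cout << "fetch reads version " << latest->version << "\n";  // 1
  }
  return 0;
}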

paddle/fluid/framework/details/threaded_ssa_graph_executor.h

Lines changed: 11 additions & 2 deletions
@@ -49,6 +49,10 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
  private:
   void RunOp(BlockingQueue<VarHandleBase *> *ready_var_q,
              details::OpHandleBase *op);
+  void RunOp(
+      std::atomic<int> *total_ops, BlockingQueue<OpHandleBase *> *pending_ops,
+      std::unordered_map<OpHandleBase *, std::atomic<int>> *pending_op_deps,
+      details::OpHandleBase *current_op);
 
  private:
   std::unique_ptr<::ThreadPool> pool_;
@@ -57,6 +61,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
   platform::DeviceContextPool fetch_ctxs_;
   std::unique_ptr<platform::EnforceNotMet> exception_;
   std::atomic<int> running_ops_;
+  ExecutionStrategy strategy_;
+  size_t thread_cnt_;
 
   void InsertPendingOp(std::unordered_map<OpHandleBase *, size_t> *pending_ops,
                        OpHandleBase *op_instance) const;
@@ -73,8 +79,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
                         std::unordered_set<VarHandleBase *> *pending_vars,
                         BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data);
 
- private:
-  ExecutionStrategy strategy_;
+  void InsertFetchOps(
+      const std::vector<std::string> &fetch_tensors,
+      std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
+      std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
+      FeedFetchList *fetch_data);
 };
 
 }  // namespace details
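One loose end the commit leaves visible is the commented-out spin on total_ops in Run(): nothing yet blocks until the enqueued workers drain the queue. Assuming the vendored single-header ThreadPool whose enqueue returns a std::future (that matches how ::ThreadPool is used in Paddle's third-party setup, but treat the exact API as an assumption), one way to wait is to keep the futures and get() them; a hypothetical sketch, not part of the commit:

#include <cstddef>
#include <functional>
#include <future>
#include <vector>
#include "ThreadPool.h"  // assumed: single-header pool, enqueue -> std::future

// worker_loop stands in for the op_run lambda built inside RunOp.
void RunWorkersAndWait(ThreadPool *pool, std::size_t thread_cnt,
                       const std::function<void()> &worker_loop) {
  std::vector<std::future<void>> done;
  done.reserve(thread_cnt);
  for (std::size_t i = 0; i < thread_cnt; ++i) {
    done.emplace_back(pool->enqueue(worker_loop));
  }
  // Block until every worker's loop has exited (total_ops hit zero),
  // rather than spinning on the counter.
  for (auto &f : done) {
    f.get();
  }
}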
