1313// limitations under the License.
1414
1515#include " paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
16+ #include " paddle/fluid/framework/threadpool.h"
1617
1718namespace paddle {
1819namespace framework {
@@ -33,55 +34,50 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
3334
3435void ThreadedSSAGraphExecutor::RunOp (
3536 std::atomic<int > *total_ops, BlockingQueue<OpHandleBase *> *ready_ops,
36- std::unordered_map<OpHandleBase *, std::atomic<int >> *pending_op_deps,
37- details::OpHandleBase *op) {
38- auto op_run = [ready_ops, pending_op_deps, op, total_ops, this ] {
39- OpHandleBase *current_op = op;
40- while (true ) {
41- // 1. If current_op is nullptr, get a runnable op from pending_ops
42- if (current_op == nullptr ) {
43- if (*total_ops <= 0 ) break ;
44- current_op = ready_ops->Pop ();
45- }
37+ std::unordered_map<OpHandleBase *, std::atomic<size_t >> *pending_op_deps) {
38+ bool timeout;
39+ // std::deque<OpHandleBase *> local_ops;
40+ OpHandleBase *current_op = nullptr ;
41+
42+ while (true ) {
43+ // 1. If current_op is nullptr, get a runnable op from pending_ops.
44+ if (current_op == nullptr ) {
45+ if ((*total_ops) <= 0 ) break ;
46+ current_op = ready_ops->Pop (1 , &timeout);
47+ if (timeout) continue ;
48+ }
4649
47- // 2. Run the current op
48- try {
49- VLOG (10 ) << current_op << " " << current_op->Name () << " : "
50- << current_op->DebugString ();
51- current_op->Run (strategy_.use_event_ );
52- VLOG (10 ) << current_op << " " << current_op->Name () << " Done " ;
53- } catch (platform::EnforceNotMet ex) {
54- exception_.reset (new platform::EnforceNotMet (ex));
55- } catch (...) {
56- LOG (FATAL) << " Unknown exception catched" ;
57- }
58- total_ops->fetch_sub (1 );
59- auto released_vars = current_op->Outputs ();
60-
61- // 3. Decrease the dependency of pending_op_deps according to
62- // released_vars. And find the runnable op.
63- current_op = nullptr ;
64- for (auto ready_var : released_vars) {
65- for (auto *op : ready_var->pending_ops_ ) {
66- auto dep_num = pending_op_deps->at (op).fetch_sub (1 );
67- if (dep_num == 0 ) {
68- if (op->IsMultiDeviceTransfer () && strategy_.allow_op_delay_ ) {
69- ready_ops->Push (op);
70- } else {
71- if (!current_op) {
72- current_op = op;
73- }
74- }
50+ // 2. Run the current op.
51+ try {
52+ VLOG (10 ) << current_op << " " << current_op->Name () << " : "
53+ << current_op->DebugString ();
54+ current_op->Run (strategy_.use_event_ );
55+ --(*total_ops);
56+ VLOG (10 ) << current_op << " " << current_op->Name () << " Done " ;
57+ } catch (platform::EnforceNotMet ex) {
58+ exception_.reset (new platform::EnforceNotMet (ex));
59+ } catch (...) {
60+ LOG (FATAL) << " Unknown exception catched" ;
61+ }
62+ auto released_vars = current_op->Outputs ();
63+
64+ // 3. Decrease the dependency of pending_op_deps. And find the runnable op.
65+ current_op = nullptr ;
66+ for (auto ready_var : released_vars) {
67+ for (auto *op : ready_var->pending_ops_ ) {
68+ auto dep_num = --pending_op_deps->at (op);
69+ if (dep_num == 0 ) {
70+ bool push_into_ready_ops =
71+ current_op != nullptr ||
72+ (op->IsMultiDeviceTransfer () && strategy_.allow_op_delay_ );
73+ if (push_into_ready_ops) {
74+ ready_ops->Push (op);
75+ } else {
76+ current_op = op;
7577 }
7678 }
7779 }
7880 }
79- };
80-
81- if (pool_) {
82- pool_->enqueue (op_run);
83- } else {
84- op_run ();
8581 }
8682}
8783
@@ -96,34 +92,65 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
9692
9793 // Step 2. Collect ready_ops and pending_op_deps
9894 BlockingQueue<OpHandleBase *> ready_ops; // read and write
99- std::unordered_map<OpHandleBase *, std::atomic<int >>
95+ std::unordered_map<OpHandleBase *, std::atomic<size_t >>
10096 pending_op_deps; // only read
10197
10298 for (auto &op : graph_->ops_ ) {
10399 if (op->Inputs ().empty ()) {
104100 ready_ops.Push (op.get ());
105101 } else {
106- pending_op_deps. insert ({ op.get (), op->NoDupInputSize ()} );
102+ pending_op_deps[ op.get ()] = op->NoDupInputSize ();
107103 }
108104 }
109105 for (auto &op : fetch_ops) {
110- pending_op_deps.insert ({op.get (), op->NoDupInputSize ()});
106+ pending_op_deps[op.get ()] = op->NoDupInputSize ();
107+ }
108+
109+ // move some pending op to ready ops
110+ for (auto &var_map : graph_->vars_ ) {
111+ for (auto &name_pair : var_map) {
112+ for (auto &version_pair : name_pair.second ) {
113+ if (version_pair->generated_op_ == nullptr ) {
114+ for (auto pending_op : version_pair->pending_ops_ ) {
115+ --pending_op_deps[pending_op];
116+ if (pending_op_deps[pending_op] == 0 ) {
117+ ready_ops.Push (pending_op);
118+ }
119+ }
120+ }
121+ }
122+ }
123+ }
124+
125+ for (auto &var : graph_->dep_vars_ ) {
126+ if (var->generated_op_ == nullptr ) {
127+ for (auto pending_op : var->pending_ops_ ) {
128+ --pending_op_deps[pending_op];
129+ if (pending_op_deps[pending_op] == 0 ) {
130+ ready_ops.Push (pending_op);
131+ }
132+ }
133+ }
111134 }
112135
113136 // according to total_ops to know whether the loop is over
114137 std::atomic<int > total_ops (
115138 static_cast <int >(graph_->ops_ .size () + fetch_ops.size ()));
116139
117140 // Step 3. Execution
141+ std::vector<std::thread> workers;
142+ workers.resize (thread_cnt_);
118143 for (size_t i = 0 ; i < thread_cnt_; ++i) {
119- RunOp (&total_ops, &ready_ops, &pending_op_deps, nullptr );
144+ workers[i] = std::thread ([&total_ops, &ready_ops, &pending_op_deps, this ] {
145+ RunOp (&total_ops, &ready_ops, &pending_op_deps);
146+ });
120147 }
121148
122- // while (true ) {
123- // if (total_ops == 0) break ;
124- // }
149+ for ( auto &worker : workers ) {
150+ worker. join () ;
151+ }
125152
126- PADDLE_ENFORCE (total_ops = = 0 );
153+ PADDLE_ENFORCE (total_ops < = 0 );
127154
128155 // Wait FetchOps.
129156 if (!fetch_ops.empty ()) {
@@ -169,46 +196,6 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
169196 }
170197}
171198
172- void ThreadedSSAGraphExecutor::InsertFetchOps (
173- const std::vector<std::string> &fetch_tensors,
174- std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
175- std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
176- std::unordered_map<OpHandleBase *, size_t > *pending_ops,
177- std::unordered_set<VarHandleBase *> *pending_vars,
178- BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
179- std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
180-
181- for (auto &fetch_var_name : fetch_tensors) {
182- for (auto &var_map : graph_->vars_ ) {
183- auto it = var_map.find (fetch_var_name);
184- if (it != var_map.end ()) {
185- fetched_vars[fetch_var_name].push_back (it->second .rbegin ()->get ());
186- }
187- }
188- }
189-
190- for (size_t i = 0 ; i < fetch_tensors.size (); ++i) {
191- auto &var_name = fetch_tensors[i];
192- auto &vars = fetched_vars.at (var_name);
193- auto *op = new FetchOpHandle (fetch_data, i, &local_scopes_);
194- fetch_ops->emplace_back (op);
195-
196- for (auto &p : places_) {
197- op->SetDeviceContext (p, fetch_ctxs_.Get (p));
198- }
199-
200- for (auto *var : vars) {
201- op->AddInput (var);
202- }
203-
204- auto *fetch_dummy = new DummyVarHandle ();
205- op->AddOutput (fetch_dummy);
206- fetch_dependencies->emplace (fetch_dummy);
207- this ->InsertPendingVar (pending_vars, ready_vars, fetch_dummy);
208- this ->InsertPendingOp (pending_ops, op);
209- }
210- }
211-
212199void ThreadedSSAGraphExecutor::InsertPendingOp (
213200 std::unordered_map<OpHandleBase *, size_t > *pending_ops,
214201 OpHandleBase *op_instance) const {
0 commit comments