1313// limitations under the License.
1414
1515#include " paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
16+ #include " paddle/fluid/framework/threadpool.h"
1617
1718namespace paddle {
1819namespace framework {
@@ -33,55 +34,49 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
3334
3435void ThreadedSSAGraphExecutor::RunOp (
3536 std::atomic<int > *total_ops, BlockingQueue<OpHandleBase *> *ready_ops,
36- std::unordered_map<OpHandleBase *, std::atomic<int >> *pending_op_deps,
37- details::OpHandleBase *op) {
38- auto op_run = [ready_ops, pending_op_deps, op, total_ops, this ] {
39- OpHandleBase *current_op = op;
40- while (true ) {
41- // 1. If current_op is nullptr, get a runnable op from pending_ops
42- if (current_op == nullptr ) {
43- if (*total_ops <= 0 ) break ;
44- current_op = ready_ops->Pop ();
45- }
37+ std::unordered_map<OpHandleBase *, std::atomic<size_t >> *pending_op_deps) {
38+ bool timeout;
39+ OpHandleBase *current_op = nullptr ;
40+
41+ while (true ) {
42+ // 1. If current_op is nullptr, get a runnable op from ready_ops.
43+ if (current_op == nullptr ) {
44+ if ((*total_ops) <= 0 ) break ;
45+ current_op = ready_ops->Pop (1 , &timeout);
46+ if (timeout) continue ;
47+ }
4648
47- // 2. Run the current op
48- try {
49- VLOG (10 ) << current_op << " " << current_op->Name () << " : "
50- << current_op->DebugString ();
51- current_op->Run (strategy_.use_event_ );
52- VLOG (10 ) << current_op << " " << current_op->Name () << " Done " ;
53- } catch (platform::EnforceNotMet ex) {
54- exception_.reset (new platform::EnforceNotMet (ex));
55- } catch (...) {
56- LOG (FATAL) << " Unknown exception catched" ;
57- }
58- total_ops->fetch_sub (1 );
59- auto released_vars = current_op->Outputs ();
60-
61- // 3. Decrease the dependency of pending_op_deps according to
62- // released_vars. And find the runnable op.
63- current_op = nullptr ;
64- for (auto ready_var : released_vars) {
65- for (auto *op : ready_var->pending_ops_ ) {
66- auto dep_num = pending_op_deps->at (op).fetch_sub (1 );
67- if (dep_num == 0 ) {
68- if (op->IsMultiDeviceTransfer () && strategy_.allow_op_delay_ ) {
69- ready_ops->Push (op);
70- } else {
71- if (!current_op) {
72- current_op = op;
73- }
74- }
49+ // 2. Run the current op.
50+ try {
51+ VLOG (10 ) << current_op << " " << current_op->Name () << " : "
52+ << current_op->DebugString ();
53+ current_op->Run (strategy_.use_event_ );
54+ --(*total_ops);
55+ VLOG (10 ) << current_op << " " << current_op->Name () << " Done " ;
56+ } catch (platform::EnforceNotMet ex) {
57+ exception_.reset (new platform::EnforceNotMet (ex));
58+ } catch (...) {
59+ LOG (FATAL) << " Unknown exception catched" ;
60+ }
61+ auto released_vars = current_op->Outputs ();
62+
63+ // 3. Decrease the dependency of pending_op_deps. And find the runnable op.
64+ current_op = nullptr ;
65+ for (auto ready_var : released_vars) {
66+ for (auto *op : ready_var->pending_ops_ ) {
67+ auto dep_num = --pending_op_deps->at (op);
68+ if (dep_num == 0 ) {
69+ bool push_into_ready_ops =
70+ current_op != nullptr ||
71+ (op->IsMultiDeviceTransfer () && strategy_.allow_op_delay_ );
72+ if (push_into_ready_ops) {
73+ ready_ops->Push (op);
74+ } else {
75+ current_op = op;
7576 }
7677 }
7778 }
7879 }
79- };
80-
81- if (pool_) {
82- pool_->enqueue (op_run);
83- } else {
84- op_run ();
8580 }
8681}
8782
@@ -95,35 +90,64 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
9590 InsertFetchOps (fetch_tensors, &fetch_ops, &fetch_dependencies, &fetch_data);
9691
9792 // Step 2. Collect ready_ops and pending_op_deps
98- BlockingQueue<OpHandleBase *> ready_ops; // read and write
99- std::unordered_map<OpHandleBase *, std::atomic<int >>
100- pending_op_deps; // only read
93+ BlockingQueue<OpHandleBase *> ready_ops;
94+ std::unordered_map<OpHandleBase *, std::atomic<size_t >> pending_op_deps;
10195
10296 for (auto &op : graph_->ops_ ) {
10397 if (op->Inputs ().empty ()) {
10498 ready_ops.Push (op.get ());
10599 } else {
106- pending_op_deps. insert ({ op.get (), op->NoDupInputSize ()} );
100+ pending_op_deps[ op.get ()] = op->NoDupInputSize ();
107101 }
108102 }
109103 for (auto &op : fetch_ops) {
110- pending_op_deps.insert ({op.get (), op->NoDupInputSize ()});
104+ pending_op_deps[op.get ()] = op->NoDupInputSize ();
105+ }
106+
107+ auto insert_ready_ops = [&ready_ops, &pending_op_deps](VarHandleBase *op) {
108+ if (op->generated_op_ == nullptr ) {
109+ for (auto pending_op : op->pending_ops_ ) {
110+ --pending_op_deps[pending_op];
111+ if (pending_op_deps[pending_op] == 0 ) {
112+ ready_ops.Push (pending_op);
113+ }
114+ }
115+ }
116+ };
117+
118+ // Insert_ready_ops
119+ for (auto &var_map : graph_->vars_ ) {
120+ for (auto &name_pair : var_map) {
121+ for (auto &version_pair : name_pair.second ) {
122+ insert_ready_ops (version_pair.get ());
123+ }
124+ }
125+ }
126+
127+ for (auto &var : graph_->dep_vars_ ) {
128+ if (var->generated_op_ == nullptr ) {
129+ insert_ready_ops (var.get ());
130+ }
111131 }
112132
113133 // according to total_ops to know whether the loop is over
114134 std::atomic<int > total_ops (
115135 static_cast <int >(graph_->ops_ .size () + fetch_ops.size ()));
116136
117137 // Step 3. Execution
138+ std::vector<std::thread> workers;
139+ workers.resize (thread_cnt_);
118140 for (size_t i = 0 ; i < thread_cnt_; ++i) {
119- RunOp (&total_ops, &ready_ops, &pending_op_deps, nullptr );
141+ workers[i] = std::thread ([&total_ops, &ready_ops, &pending_op_deps, this ] {
142+ RunOp (&total_ops, &ready_ops, &pending_op_deps);
143+ });
120144 }
121145
122- // while (true ) {
123- // if (total_ops == 0) break ;
124- // }
146+ for ( auto &worker : workers ) {
147+ worker. join () ;
148+ }
125149
126- PADDLE_ENFORCE (total_ops = = 0 );
150+ PADDLE_ENFORCE (total_ops < = 0 );
127151
128152 // Wait FetchOps.
129153 if (!fetch_ops.empty ()) {
@@ -169,46 +193,6 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
169193 }
170194}
171195
172- void ThreadedSSAGraphExecutor::InsertFetchOps (
173- const std::vector<std::string> &fetch_tensors,
174- std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
175- std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
176- std::unordered_map<OpHandleBase *, size_t > *pending_ops,
177- std::unordered_set<VarHandleBase *> *pending_vars,
178- BlockingQueue<VarHandleBase *> *ready_vars, FeedFetchList *fetch_data) {
179- std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
180-
181- for (auto &fetch_var_name : fetch_tensors) {
182- for (auto &var_map : graph_->vars_ ) {
183- auto it = var_map.find (fetch_var_name);
184- if (it != var_map.end ()) {
185- fetched_vars[fetch_var_name].push_back (it->second .rbegin ()->get ());
186- }
187- }
188- }
189-
190- for (size_t i = 0 ; i < fetch_tensors.size (); ++i) {
191- auto &var_name = fetch_tensors[i];
192- auto &vars = fetched_vars.at (var_name);
193- auto *op = new FetchOpHandle (fetch_data, i, &local_scopes_);
194- fetch_ops->emplace_back (op);
195-
196- for (auto &p : places_) {
197- op->SetDeviceContext (p, fetch_ctxs_.Get (p));
198- }
199-
200- for (auto *var : vars) {
201- op->AddInput (var);
202- }
203-
204- auto *fetch_dummy = new DummyVarHandle ();
205- op->AddOutput (fetch_dummy);
206- fetch_dependencies->emplace (fetch_dummy);
207- this ->InsertPendingVar (pending_vars, ready_vars, fetch_dummy);
208- this ->InsertPendingOp (pending_ops, op);
209- }
210- }
211-
212196void ThreadedSSAGraphExecutor::InsertPendingOp (
213197 std::unordered_map<OpHandleBase *, size_t > *pending_ops,
214198 OpHandleBase *op_instance) const {
0 commit comments