@@ -28,100 +28,102 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
       places_(places),
       fetch_ctxs_(places),
       running_ops_(0),
-      strategy_(strategy) {}
+      strategy_(strategy),
+      thread_cnt_(strategy.num_threads_) {}
 
-FeedFetchList ThreadedSSAGraphExecutor::Run(
-    const std::vector<std::string> &fetch_tensors) {
-  std::unordered_map<OpHandleBase *, size_t> pending_ops;
-  std::unordered_set<VarHandleBase *> pending_vars;
-  BlockingQueue<VarHandleBase *> ready_vars;
-  std::unordered_set<OpHandleBase *> ready_ops;
-  // For ops (e.g. nccl_all_reduce) that need to coordinate multiple
-  // streams from multiple GPUs, it's faster to buffer them and schedule
-  // together since we currently cannot overlap computation and memcpy streams.
-  // Should revisit it if overlapping is available.
-  std::unordered_set<OpHandleBase *> delayed_ops;
-
-  // Transform SSAGraph to pending_ops & pending_vars
-  for (auto &var_map : graph_->vars_) {
-    for (auto &name_pair : var_map) {
-      for (auto &version_pair : name_pair.second) {
-        InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
+void ThreadedSSAGraphExecutor::RunOp(
+    std::atomic<int> *total_ops, BlockingQueue<OpHandleBase *> *ready_ops,
+    std::unordered_map<OpHandleBase *, std::atomic<int>> *pending_op_deps,
+    details::OpHandleBase *op) {
+  auto op_run = [ready_ops, pending_op_deps, op, total_ops, this] {
+    OpHandleBase *current_op = op;
+    while (true) {
+      // 1. If current_op is nullptr, take a runnable op from ready_ops
+      if (current_op == nullptr) {
+        if (*total_ops <= 0) break;
+        current_op = ready_ops->Pop();
       }
-    }
-  }
-  for (auto &var : graph_->dep_vars_) {
-    InsertPendingVar(&pending_vars, &ready_vars, var.get());
-  }
 
-  for (auto &op : graph_->ops_) {
-    if (op->Inputs().empty()) {  // Special case, Op has no input.
-      ready_ops.insert(op.get());
-    } else {
-      InsertPendingOp(&pending_ops, op.get());
+      // 2. Run the current op
+      try {
+        VLOG(10) << current_op << " " << current_op->Name() << " : "
+                 << current_op->DebugString();
+        current_op->Run(strategy_.use_event_);
+        VLOG(10) << current_op << " " << current_op->Name() << " Done ";
+      } catch (platform::EnforceNotMet ex) {
+        exception_.reset(new platform::EnforceNotMet(ex));
+      } catch (...) {
+        LOG(FATAL) << "Unknown exception caught";
+      }
+      total_ops->fetch_sub(1);
+      auto released_vars = current_op->Outputs();
+
+      // 3. Decrease the dependency counts in pending_op_deps for the ops
+      // that consume released_vars, and pick a newly runnable op to run next.
+      current_op = nullptr;
+      for (auto ready_var : released_vars) {
+        for (auto *op : ready_var->pending_ops_) {
+          auto dep_num = pending_op_deps->at(op).fetch_sub(1);
+          // fetch_sub returns the value held *before* the decrement, so the
+          // op becomes runnable when its last dependency (count 1) is released.
+          if (dep_num == 1) {
+            if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
+              ready_ops->Push(op);
+            } else if (current_op == nullptr) {
+              current_op = op;
+            } else {
+              // This worker already picked its next op; queue the extra ready
+              // op so it is not lost.
+              ready_ops->Push(op);
+            }
+          }
+        }
+      }
     }
+  };
+
+  if (pool_) {
+    pool_->enqueue(op_run);
+  } else {
+    op_run();
   }
+}
 
-  // Step 2. Insert FetchOps
+FeedFetchList ThreadedSSAGraphExecutor::Run(
+    const std::vector<std::string> &fetch_tensors) {
+  // Step 1. Insert FetchOps
   std::vector<std::unique_ptr<FetchOpHandle>> fetch_ops;
   std::unordered_set<std::unique_ptr<VarHandleBase>> fetch_dependencies;
   FeedFetchList fetch_data(fetch_tensors.size());
 
-  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &pending_ops,
-                 &pending_vars, &ready_vars, &fetch_data);
+  InsertFetchOps(fetch_tensors, &fetch_ops, &fetch_dependencies, &fetch_data);
 
-  auto run_all_ops = [&](std::unordered_set<OpHandleBase *> &set) {
-    for (auto *op : set) {
-      running_ops_++;
-      RunOp(&ready_vars, op);
-    }
-    set.clear();
-  };
+  // Step 2. Collect ready_ops and pending_op_deps
+  BlockingQueue<OpHandleBase *> ready_ops;  // pushed and popped by the workers
+  std::unordered_map<OpHandleBase *, std::atomic<int>>
+      pending_op_deps;  // map layout fixed after Step 2; counters are atomic
 
-  // Step 3. Execution
-  while (!pending_vars.empty()) {
-    // 1. Run All Ready ops
-    // Keep loop until all vars are ready.
-    //
-    // NOTE: DelayedOps have a lower priority. It will be scheduled after all
-    // ready_ops have been performed.
-    if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) {
-      run_all_ops(delayed_ops);
+  for (auto &op : graph_->ops_) {
+    if (op->Inputs().empty()) {
+      ready_ops.Push(op.get());
     } else {
-      run_all_ops(ready_ops);
+      // emplace: std::atomic<int> cannot be copied into the map via insert
+      // (see the note after the diff)
+      pending_op_deps.emplace(op.get(), static_cast<int>(op->NoDupInputSize()));
     }
+  }
+  for (auto &op : fetch_ops) {
+    pending_op_deps.emplace(op.get(), static_cast<int>(op->NoDupInputSize()));
+  }
 
-    // 2. Find ready variable
-    bool timeout;
-    auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
-
-    if (timeout) {
-      if (exception_) {
-        auto exp = *exception_;
-        exception_.reset();
-        throw exp;
-      } else {
-        continue;
-      }
-    }
-    // 3. Remove the dependency of ready_var.
-    // Find the ready_ops after the ready_var.
-    for (auto ready_var : cur_ready_vars) {
-      pending_vars.erase(ready_var);
-      for (auto *op : ready_var->pending_ops_) {
-        auto &deps = pending_ops[op];
-        --deps;
-        if (deps == 0) {
-          if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
-            delayed_ops.insert(op);
-          } else {
-            ready_ops.insert(op);
-          }
-        }
-      }
-    }
+  // total_ops tracks how many ops are still unfinished; the worker loops
+  // use it to know when execution is over.
+  std::atomic<int> total_ops(
+      static_cast<int>(graph_->ops_.size() + fetch_ops.size()));
+
+  // Step 3. Execution
+  for (size_t i = 0; i < thread_cnt_; ++i) {
+    RunOp(&total_ops, &ready_ops, &pending_op_deps, nullptr);
   }
-  PADDLE_ENFORCE(ready_ops.empty());
+
+  // while (true) {
+  //   if (total_ops == 0) break;
+  // }
+
+  PADDLE_ENFORCE(total_ops == 0);
 
   // Wait FetchOps.
   if (!fetch_ops.empty()) {
@@ -131,6 +133,42 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   return fetch_data;
 }
 
+void ThreadedSSAGraphExecutor::InsertFetchOps(
+    const std::vector<std::string> &fetch_tensors,
+    std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
+    std::unordered_set<std::unique_ptr<VarHandleBase>> *fetch_dependencies,
+    FeedFetchList *fetch_data) {
+  std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
+
+  for (auto &fetch_var_name : fetch_tensors) {
+    for (auto &var_map : graph_->vars_) {
+      auto it = var_map.find(fetch_var_name);
+      if (it != var_map.end()) {
+        fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());
+      }
+    }
+  }
+
+  for (size_t i = 0; i < fetch_tensors.size(); ++i) {
+    auto &var_name = fetch_tensors[i];
+    auto &vars = fetched_vars.at(var_name);
+    auto *op = new FetchOpHandle(fetch_data, i, &local_scopes_);
+    fetch_ops->emplace_back(op);
+
+    for (auto &p : places_) {
+      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
+    }
+
+    for (auto *var : vars) {
+      op->AddInput(var);
+    }
+
+    auto *fetch_dummy = new DummyVarHandle();
+    op->AddOutput(fetch_dummy);
+    fetch_dependencies->emplace(fetch_dummy);
+  }
+}
+
 void ThreadedSSAGraphExecutor::InsertFetchOps(
     const std::vector<std::string> &fetch_tensors,
     std::vector<std::unique_ptr<FetchOpHandle>> *fetch_ops,
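
For readers who want to experiment with the scheduling pattern the new RunOp introduces (a shared blocking queue of ready ops plus one atomic dependency counter per op), here is a standalone sketch with toy types. Nothing below is PaddlePaddle code: SimpleBlockingQueue, Node, and the poison-pill shutdown are made up for illustration. The key detail is that std::atomic::fetch_sub returns the value held before the decrement, so the worker that observes 1 is the one that released the op's last dependency.

// Standalone illustration -- hypothetical types, not part of PaddlePaddle.
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal blocking queue standing in for details::BlockingQueue.
template <typename T>
class SimpleBlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> guard(mu_);
      queue_.push(std::move(v));
    }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> queue_;
};

struct Node {
  int id = 0;
  std::atomic<int> deps{0};        // unfinished predecessors
  std::vector<Node *> successors;  // ops consuming this op's outputs
};

int main() {
  // Tiny DAG: op 0 and op 1 both feed op 2.
  std::vector<Node> nodes(3);
  for (int i = 0; i < 3; ++i) nodes[i].id = i;
  nodes[0].successors = {&nodes[2]};
  nodes[1].successors = {&nodes[2]};
  nodes[2].deps = 2;

  SimpleBlockingQueue<Node *> ready_ops;
  std::atomic<int> total_ops(3);
  ready_ops.Push(&nodes[0]);
  ready_ops.Push(&nodes[1]);

  const int kWorkers = 2;
  auto worker = [&] {
    while (true) {
      Node *op = ready_ops.Pop();
      if (op == nullptr) break;  // poison pill: everything is done
      std::printf("run op %d\n", op->id);
      for (Node *succ : op->successors) {
        // fetch_sub returns the value before the decrement, so observing 1
        // means this thread released the successor's last dependency.
        if (succ->deps.fetch_sub(1) == 1) ready_ops.Push(succ);
      }
      // The worker that finishes the last op wakes every worker up.
      if (total_ops.fetch_sub(1) == 1) {
        for (int i = 0; i < kWorkers; ++i) ready_ops.Push(nullptr);
      }
    }
  };

  std::thread t0(worker), t1(worker);
  t0.join();
  t1.join();
  return 0;
}

The poison-pill shutdown is one way to keep a worker from blocking forever in Pop() after the last op finishes; the diff's `if (*total_ops <= 0) break;` check only helps when a worker happens to re-check before popping, which is presumably what the commented-out busy-wait in Run() is probing at.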
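A note on the two pending_op_deps initialisation lines, which this rewrite changes from insert to emplace: std::atomic<int> can be neither copied nor moved, so a pair containing it cannot be pushed into the map after being constructed; emplace builds the counter in place instead. A minimal demonstration with toy keys (not Paddle code):

#include <atomic>
#include <unordered_map>

int main() {
  std::unordered_map<int, std::atomic<int>> deps;
  // deps.insert({7, 3});  // does not compile: inserting the pair would have
  //                       // to copy or move std::atomic<int>, which is deleted
  deps.emplace(7, 3);      // constructs the atomic counter in place
  // fetch_sub returns the value held before the decrement: 3 here.
  return deps.at(7).fetch_sub(1) == 3 ? 0 : 1;
}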