@@ -37,6 +37,8 @@ struct ExecutorPrepareContext;
 }  // namespace framework
 }  // namespace paddle
 
+DECLARE_double(eager_delete_tensor_gb);
+
 namespace paddle {
 namespace distributed {
 
@@ -66,9 +68,9 @@ class TensorTable : public Table {
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
-  virtual int32_t initialize_shard() { return 0; };
+  virtual int32_t initialize_shard() { return 0; }
 
-  virtual int32_t flush() { return 0; };
+  virtual int32_t flush() { return 0; }
 
   virtual int32_t load(const std::string &path, const std::string &param) {
     return 0;
@@ -77,18 +79,23 @@ class TensorTable : public Table {
     return 0;
   }
 
-  virtual void clear(){};
+  virtual void clear() {}
 
-  virtual int32_t initialize() override { return 0; };
+  int32_t initialize() override { return 0; }
 
-  virtual int32_t push_dense(const int64_t *values,
-                             const int32_t trainer_id) override {
+  int32_t push_dense(const int64_t *values, const int32_t trainer_id) override {
     return 0;
-  };
+  }
 
-  virtual int32_t set_program_env(
+  int32_t set_program_env(
       framework::Scope *scope, platform::Place place,
-      const std::vector<framework::ProgramDesc> *sub_program) override;
+      const std::vector<framework::ProgramDesc> *sub_program) override {
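+    // Cache the scope and place, build an Executor on that place, and keep the
+    // sub-programs so that initialize() can Prepare/Run them later.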
+    scope_ = scope;
+    place_ = place;
+    executor_ = new framework::Executor(place_);
+    sub_program_ = sub_program;
+    return 0;
+  }
 
  protected:
   framework::Executor *executor_;
@@ -135,7 +142,7 @@ class DenseTensorTable : public TensorTable {
 
   /* ----------------------------------------------------------------------*/
 
-  virtual int32_t initialize() override { return 0; }
+  int32_t initialize() override { return 0; }
 
   int32_t push_dense(const float *values, size_t num) override { return 0; }
 
@@ -189,18 +196,98 @@ class GlobalStepTable : public DenseTensorTable {
 
   /* ----------------------------------------------------------------------*/
 
-  int32_t initialize() override;
+  int32_t initialize() override {
+    auto _program_config = _config.tensor();
+    auto trainers_ = _config.common().trainer_num();
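+    // Disable eager tensor deletion (negative value) so variables created by
+    // the prepared programs persist across executor runs.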
+    FLAGS_eager_delete_tensor_gb = -1;
+    // Get Config
+    if (_program_config.has_startup_program_id()) {
+      startup_program_id_ = _program_config.startup_program_id();
+    }
+    if (_program_config.has_main_program_id()) {
+      main_program_id_ = _program_config.main_program_id();
+    }
+    if (_program_config.has_feed_var_name()) {
+      feed_var_name_ = _program_config.feed_var_name();
+    }
+    if (_program_config.has_fetch_var_name()) {
+      fetch_var_name_ = _program_config.fetch_var_name();
+    }
+
+    // Run startup program
+    if (startup_program_id_ != -1) {
+      std::map<std::string, const framework::LoDTensor *> fake_feed;
+      std::map<std::string, framework::FetchType *> fake_fetch;
+      auto startup_program_desc = sub_program_->at(startup_program_id_);
+      auto ctx = executor_->Prepare(startup_program_desc, 0);
+      executor_->RunPreparedContext(ctx.get(), scope_, false);
+    }
+
+    if (main_program_id_ != -1) {
+      // Run the main program, which is used for learning rate decay
+      auto main_program_desc = sub_program_->at(main_program_id_);
+      auto main_ctx = executor_->Prepare(main_program_desc, 0);
+      exec_context_ = std::move(main_ctx);
+      executor_->RunPreparedContext(exec_context_.get(), scope_, false);
+      // init decay_counters
+      decay_counters_.reserve(trainers_);
+      for (int32_t i = 0; i < trainers_; ++i) {
+        decay_counters_[i] = 0;
+      }
+    }
+    return 0;
+  }
 
   int32_t push_dense(const float *values, size_t num) override { return 0; }
 
-  int32_t push_dense(const int64_t *values, const int32_t trainer_id);
+  int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
+    return _run_program(values, trainer_id);
+  }
 
-  int32_t set_table_map(
-      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
+  int32_t set_table_map(std::unordered_map<uint32_t, std::shared_ptr<Table>>
+                            *table_map) override {
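+    // Fetch the learning rate produced by the decay program and share its
+    // buffer with every other table in the map (skipping this table itself).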
+    auto *lr_var = scope_->FindVar(fetch_var_name_);
+    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
+    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
+    VLOG(3) << "GlobalStepTable::set_table_map set global lr: " << *lr_value;
+
+    for (auto iter = table_map->begin(); iter != table_map->end(); iter++) {
+      auto table_id = iter->first;
+      if (table_id == _config.table_id()) {
+        continue;
+      }
+      iter->second->set_global_lr(lr_value);
+    }
+    return 0;
+  }
 
  private:
   virtual int32_t _run_program(const int64_t *values,
-                               const uint32_t trainer_id);
+                               const uint32_t trainer_id) {
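+    // Accumulate this trainer's step count, feed the summed global step to the
+    // prepared decay program, and re-run it to refresh the learning rate.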
+    FLAGS_eager_delete_tensor_gb = -1;
+    auto counter = decay_counters_.at(trainer_id);
+    counter += int(values[0]);
+    decay_counters_.at(trainer_id) = counter;
+
+    auto *global_step_var = scope_->FindVar(feed_var_name_);
+    auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
+    auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
+
+    auto global_counter = 0;
+    for (auto &trainer_counter : decay_counters_) {
+      global_counter += trainer_counter.second;
+    }
+
+    // TODO: hard code for increment op
+    value[0] = global_counter - 1;
+    VLOG(3) << "GlobalStepTable::_run_program global_counter " << value[0];
+
+    executor_->RunPreparedContext(exec_context_.get(), scope_, false, false);
+    auto *lr_var = scope_->FindVar(fetch_var_name_);
+    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
+    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
+    VLOG(3) << "GlobalStepTable::LR value: " << lr_value[0];
+    return 0;
+  }
 
  private:
   std::unordered_map<int, int64_t> decay_counters_;