@@ -34,6 +34,9 @@ limitations under the License. */
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/split.h"
 
+#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
+#define STEP_COUNTER "@PS_STEP_COUNTER@"
+
 namespace paddle {
 namespace distributed {
 
@@ -377,6 +380,37 @@ void Communicator::RpcProfilerControl() {
   }
 }
 
+void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
+                                  Scope *send_scope) {
+  if (batches == 0) {
+    return;
+  }
+  auto &table_id = ctx.table_id;
+  size_t request_call_num = _worker_ptr->get_server_nums();
+
+  auto &var_name = STEP_COUNTER;
+  auto *out_var = send_scope->Var(var_name);
+  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
+  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
+  data[0] = static_cast<int64_t>(batches);
+  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [this, request_call_num](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto status = _worker_ptr->push_global_step(table_id, data, closure);
+  status.wait();
+  return;
+}
+
 void AsyncCommunicator::RecvThread() {
   if (!independent_recv_) return;
   VLOG(3) << "Independent RecvThread Start and Wait";
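The new `SendGlobalStep` packs the batch count into a one-element int64 tensor, pushes it to every parameter server, and blocks on a promise that the BRPC closure fulfils after checking each server's response. The sketch below is a standalone model of that fan-out-and-wait control flow using only the standard library; `FakeClosure`, the response-code vector, and the simulated RPC thread are illustrative stand-ins, not Paddle APIs.

```cpp
// Standalone model of the fan-out-and-wait pattern used by SendGlobalStep.
// All names here (FakeClosure, the simulated RPC thread) are illustrative,
// not Paddle APIs.
#include <future>
#include <iostream>
#include <thread>
#include <vector>

class FakeClosure {
 public:
  explicit FakeClosure(size_t expected) : expected_(expected) {}
  // Called once all server responses arrive; records -1 if any check fails.
  void Run(const std::vector<int> &response_codes) {
    int ret = 0;
    for (size_t i = 0; i < expected_; ++i) {
      if (response_codes[i] != 0) {  // mirrors closure->check_response(i, ...)
        ret = -1;
        break;
      }
    }
    promise_.set_value(ret);  // mirrors closure->set_promise_value(ret)
  }
  std::future<int> GetFuture() { return promise_.get_future(); }

 private:
  size_t expected_;
  std::promise<int> promise_;
};

int main() {
  const size_t server_num = 3;  // stand-in for _worker_ptr->get_server_nums()
  FakeClosure closure(server_num);
  auto status = closure.GetFuture();

  // Simulate the asynchronous push; the real code hands the closure to
  // _worker_ptr->push_global_step(table_id, data, closure).
  std::thread rpc([&closure, server_num] {
    std::vector<int> codes(server_num, 0);  // all servers acknowledge
    closure.Run(codes);
  });

  // Mirrors status.wait() in SendGlobalStep: block until every response is checked.
  std::cout << "push_global_step returned " << status.get() << std::endl;
  rpc.join();
  return 0;
}
```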
@@ -465,10 +499,16 @@ void AsyncCommunicator::SendByCommunicator() {
 
       for (size_t i = 0; i < var_nums; i++) {
         auto &var_name = varnames[i];
-        MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
+        if (var_name == STEP_COUNTER) {
+          MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
+        } else {
+          MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
+        }
       }
 
-      if (ctx.is_sparse) {
+      if (ctx.is_tensor_table) {
+        SendGlobalStep(ctx, merged_var_num, send_scope_.get());
+      } else if (ctx.is_sparse) {
         PADDLE_ENFORCE_EQ(
             varnames.size(), 1,
             platform::errors::InvalidArgument(
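With this change the send loop merges `STEP_COUNTER` as `int64_t` and routes tensor-table contexts to `SendGlobalStep` rather than the sparse/dense paths. Assuming `MergeVars` with a final argument of 1 sums the queued copies element-wise, merging the step-counter variables amounts to counting how many batches were enqueued since the last send; the simplified standalone sketch below models that (the real code operates on `framework::LoDTensor` inside a `Scope`).

```cpp
// Simplified model of merging queued STEP_COUNTER variables.
// Assumes MergeVars<int64_t>(..., 1) sums the queued copies element-wise.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t MergeStepCounters(const std::vector<std::vector<int64_t>> &queued) {
  int64_t merged = 0;
  for (const auto &var : queued) {
    merged += var[0];  // each queued variable holds a single step count
  }
  return merged;
}

int main() {
  // Check() pushes a value of 1 each time the step-counter table is touched,
  // so merging N queued copies yields the number of batches since the last send.
  std::vector<std::vector<int64_t>> queued = {{1}, {1}, {1}};
  std::cout << "batches to send: " << MergeStepCounters(queued) << std::endl;
  return 0;
}
```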
@@ -599,8 +639,18 @@ bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
       platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
 
   auto table_name = var_tables[0];
-  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end())
+  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
     return false;
+  }
+  if (table_name == STEP_COUNTER) {
+    VLOG(3) << "send step_counter into queue";
+    auto tmp_var = std::make_shared<Variable>();
+    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
+    tensor->Resize(framework::make_ddim({1}));
+    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
+    out_d[0] = 1;
+    send_varname_to_queue_[table_name]->Push(tmp_var);
+  }
   return true;
 }
 
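`Check` now acts as the producer side of the step counter: whenever the step-counter table name is seen, it enqueues a one-element int64 tensor holding 1, which the send thread later pops and merges into the batch count. Below is a minimal standalone model of that producer behaviour; `BlockingQueue` here is a hypothetical stand-in for Paddle's internal queue type, not the actual class.

```cpp
// Minimal model of the producer side: every time the step-counter table is
// checked, one variable holding the value 1 is pushed onto its send queue.
// BlockingQueue is an illustrative stand-in, not Paddle's queue type.
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>

template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    std::lock_guard<std::mutex> lock(mu_);
    q_.push(std::move(v));
  }
  size_t Size() {
    std::lock_guard<std::mutex> lock(mu_);
    return q_.size();
  }

 private:
  std::mutex mu_;
  std::queue<T> q_;
};

int main() {
  BlockingQueue<std::shared_ptr<int64_t>> step_counter_queue;

  // Each training step that reaches Check("@PS_STEP_COUNTER@") enqueues a 1.
  for (int step = 0; step < 5; ++step) {
    step_counter_queue.Push(std::make_shared<int64_t>(1));
  }

  // The send thread later pops and merges these into the batch count.
  std::cout << "queued step markers: " << step_counter_queue.Size() << std::endl;
  return 0;
}
```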