Commit 0f4b6d6

Merge branch 'release/2.0' into rel/2.0-op-version-2
2 parents 580746f + e7cbc43

177 files changed: +6082 additions, -1461 deletions (only a subset of the changed files is shown below)


paddle/fluid/distributed/communicator_common.h

Lines changed: 7 additions & 2 deletions
```diff
@@ -30,7 +30,8 @@ struct CommContext {
               const std::vector<int64_t> &sections,
               const std::vector<std::string> &origin_names, int id,
               bool merge_add_ = true, bool is_sparse_ = true,
-              bool is_distributed_ = false, int table_id_ = -1)
+              bool is_distributed_ = false, int table_id_ = -1,
+              bool is_tensor_table_ = false)
       : var_name(name),
         splited_varnames(names),
         epmap(emap),
@@ -40,7 +41,8 @@ struct CommContext {
         merge_add(merge_add_),
         is_sparse(is_sparse_),
         is_distributed(is_distributed_),
-        table_id(table_id_) {}
+        table_id(table_id_),
+        is_tensor_table(is_tensor_table_) {}

   CommContext(const CommContext &ctx) {
     var_name = ctx.var_name;
@@ -53,6 +55,7 @@ struct CommContext {
     origin_varnames = ctx.origin_varnames;
     is_distributed = ctx.is_distributed;
     table_id = ctx.table_id;
+    is_tensor_table = ctx.is_tensor_table;
   }

   std::string print() const {
@@ -75,6 +78,7 @@ struct CommContext {
     ss << " is_sparse: " << is_sparse;
     ss << " is_distributed: " << is_distributed << "\n";
     ss << " table_id: " << table_id << "\n";
+    ss << " is_tensor_table: " << is_tensor_table << "\n";

     return ss.str();
   }
@@ -89,6 +93,7 @@ struct CommContext {
   bool is_sparse;
   bool is_distributed;
   int table_id;
+  bool is_tensor_table;
 };

 }  // namespace distributed
```
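The flag rides along as a defaulted trailing constructor argument, so existing call sites compile unchanged. Below is a hypothetical construction sketch for a step-counter context; the three leading parameters (name, split names, endpoint map) are inferred from the initializer list above, and every value is a placeholder, not something taken from this commit.

```cpp
#include "paddle/fluid/distributed/communicator_common.h"

// Hypothetical sketch: a CommContext for the step-counter send path with the
// new trailing flag set. All values below are placeholders.
paddle::distributed::CommContext MakeStepCounterContext() {
  return paddle::distributed::CommContext(
      "@PS_STEP_COUNTER@",    // name -> var_name
      {"@PS_STEP_COUNTER@"},  // names -> splited_varnames
      {"127.0.0.1:8000"},     // emap -> endpoint map (placeholder endpoint)
      {1},                    // sections
      {"@PS_STEP_COUNTER@"},  // origin_names
      0,                      // id
      true,                   // merge_add_
      false,                  // is_sparse_
      false,                  // is_distributed_
      2,                      // table_id_ (placeholder)
      true);                  // is_tensor_table_, the new flag
}
```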

paddle/fluid/distributed/fleet.cc

Lines changed: 5 additions & 4 deletions
```diff
@@ -53,15 +53,16 @@ void FleetWrapper::LoadSparseOnServer(const std::string& path,
   pserver_ptr_->_server_ptr->table(table_id)->load(path, meta);
 }

-void FleetWrapper::InitServer(const std::string& dist_desc,
-                              const std::vector<std::string>& host_sign_list,
-                              int index) {
+void FleetWrapper::InitServer(
+    const std::string& dist_desc,
+    const std::vector<std::string>& host_sign_list, int index,
+    const std::vector<framework::ProgramDesc>& server_sub_program) {
   if (!is_initialized_) {
     VLOG(3) << "Going to init server";
     pserver_ptr_ = std::shared_ptr<paddle::distributed::PSCore>(
         new paddle::distributed::PSCore());
     pserver_ptr_->init_server(dist_desc, &host_sign_list, host_sign_list.size(),
-                              index);
+                              index, server_sub_program);
     is_initialized_ = true;
   } else {
     VLOG(3) << "Server can be initialized only once";
```

paddle/fluid/distributed/fleet.h

Lines changed: 4 additions & 2 deletions
```diff
@@ -154,8 +154,10 @@ class FleetWrapper {
   // init server
   // void InitServer(const std::string& dist_desc,
   //                 const std::vector<uint64_t>& host_sign_list, int index);
-  void InitServer(const std::string& dist_desc,
-                  const std::vector<std::string>& host_sign_list, int index);
+  void InitServer(
+      const std::string& dist_desc,
+      const std::vector<std::string>& host_sign_list, int index,
+      const std::vector<framework::ProgramDesc>& server_sub_program = {});
   // init trainer
   void InitWorker(const std::string& dist_desc,
                   const std::vector<std::string>& host_sign_list, Scope* scope,
```
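A minimal usage sketch, assuming dist_desc already holds the parameter-server description and host_sign_list is collected; server_sub_program is the new hook that hands ProgramDescs (e.g. the tensor table's startup and main programs) through to init_server. Direct construction of FleetWrapper here is for brevity only.

```cpp
#include "paddle/fluid/distributed/fleet.h"

// Sketch: wiring the new server_sub_program argument through InitServer.
void InitServerSketch(const std::string &dist_desc,
                      const std::vector<std::string> &host_sign_list) {
  paddle::distributed::FleetWrapper fleet;
  // Placeholder: e.g. the tensor table's startup and main ProgramDescs.
  std::vector<paddle::framework::ProgramDesc> server_sub_program(2);
  fleet.InitServer(dist_desc, host_sign_list, /*index=*/0, server_sub_program);
  // Old call sites stay source-compatible because the parameter defaults to {}:
  fleet.InitServer(dist_desc, host_sign_list, /*index=*/0);
}
```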

paddle/fluid/distributed/ps.proto

Lines changed: 5 additions & 6 deletions
```diff
@@ -126,12 +126,11 @@ message TableAccessorParameter {
 }

 message TensorAccessorParameter {
-  optional string tensor_class = 1;
-  optional uint32 fea_dim = 2;
-  optional uint32 emb_dim = 3;
-  optional string param = 4;
-  optional string grad = 5;
-  optional string common_block_map = 6;
+  optional string feed_var_name = 1;
+  optional string fetch_var_name = 2;
+  optional int64 startup_program_id = 3;
+  optional int64 main_program_id = 4;
+  optional string tensor_table_class = 6;
 }

 message CommonAccessorParameter {
```
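Note the field numbering: tag 5 is left unused and tensor_table_class keeps tag 6, the slot of the old common_block_map string. A hedged C++ sketch of filling the reworked message, relying only on the setters protoc generates for optional fields; the variable names, program ids, and table class string are placeholders, not values taken from this commit.

```cpp
#include "paddle/fluid/distributed/ps.pb.h"  // generated header; path assumed

// Sketch: populating the new TensorAccessorParameter fields.
void FillTensorAccessorSketch(
    paddle::distributed::TensorAccessorParameter *tensor_param) {
  tensor_param->set_feed_var_name("@PS_STEP_COUNTER@");  // placeholder
  tensor_param->set_fetch_var_name("learning_rate");     // placeholder
  tensor_param->set_startup_program_id(0);
  tensor_param->set_main_program_id(1);
  tensor_param->set_tensor_table_class("GlobalStepTable");  // assumed name
}
```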

paddle/fluid/distributed/service/brpc_ps_client.cc

Lines changed: 28 additions & 0 deletions
```diff
@@ -719,6 +719,34 @@ std::future<int32_t> BrpcPsClient::push_dense_raw_gradient(
   return fut;
 }

+std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
+                                                    int64_t *total_send_data,
+                                                    void *done) {
+  size_t request_call_num = _server_channels.size();
+  DownpourBrpcClosure *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int> fut = promise->get_future();
+  for (size_t i = 0; i < request_call_num; ++i) {
+    closure->request(i)->set_cmd_id(PS_PUSH_GLOBAL_STEP);
+    closure->request(i)->set_table_id(table_id);
+    closure->request(i)->set_client_id(_client_id);
+    auto *push_data = closure->request(i)->mutable_data();
+    push_data->clear();
+    int32_t num_per_shard = 1;
+    push_data->resize(sizeof(uint32_t) + num_per_shard * sizeof(int64_t));
+    char *push_data_ptr = const_cast<char *>(push_data->data());
+    memcpy(push_data_ptr, &num_per_shard, sizeof(uint32_t));
+    memcpy(push_data_ptr + sizeof(uint32_t), total_send_data,
+           num_per_shard * sizeof(int64_t));
+
+    PsService_Stub rpc_stub(get_dense_channel(i));
+    rpc_stub.service(closure->cntl(i), closure->request(i),
+                     closure->response(i), closure);
+  }
+  return fut;
+}
+
 std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
                                                size_t table_id,
                                                const uint64_t *keys,
```
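The request payload is a flat, length-prefixed buffer: a uint32 count followed by that many int64 values. A self-contained sketch of decoding it, mirroring what the server side reads back (a hypothetical helper, not part of the commit):

```cpp
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Decodes [uint32 num | int64 values[num]] as written by push_global_step.
std::vector<int64_t> DecodeGlobalStepPayload(const std::string &data) {
  uint32_t num = 0;
  std::memcpy(&num, data.data(), sizeof(uint32_t));
  std::vector<int64_t> values(num);
  std::memcpy(values.data(), data.data() + sizeof(uint32_t),
              num * sizeof(int64_t));
  return values;
}
```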

paddle/fluid/distributed/service/brpc_ps_client.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -140,7 +140,9 @@ class BrpcPsClient : public PSClient {
                                   std::vector<float> *values,
                                   std::vector<uint64_t> *keys,
                                   int pserver_idx);
-
+  virtual std::future<int32_t> push_global_step(int table_id,
+                                                int64_t *total_send_data,
+                                                void *done);
   virtual std::future<int32_t> flush();

   virtual std::future<int32_t> send_client2client_msg(
```

paddle/fluid/distributed/service/brpc_ps_server.cc

Lines changed: 22 additions & 0 deletions
```diff
@@ -100,6 +100,7 @@ int32_t PsService::initialize() {
   _service_handler_map[PS_BARRIER] = &PsService::barrier;
   _service_handler_map[PS_START_PROFILER] = &PsService::start_profiler;
   _service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler;
+  _service_handler_map[PS_PUSH_GLOBAL_STEP] = &PsService::push_global_step;

   // shard initialization: the shard info of server_list can be read from env only after the server starts
   initialize_shard_info();
@@ -526,5 +527,26 @@ int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request,
   return 0;
 }

+int32_t PsService::push_global_step(Table *table,
+                                    const PsRequestMessage &request,
+                                    PsResponseMessage &response,
+                                    brpc::Controller *cntl) {
+  CHECK_TABLE_EXIST(table, request, response);
+  auto req_buffer_size = request.data().size();
+  if (req_buffer_size < 1) {
+    set_response_code(response, 0, "run_program data is empty");
+    return 0;
+  }
+  uint32_t num = *(const uint32_t *)(request.data().data());
+  const int64_t *values =
+      (const int64_t *)(request.data().data() + sizeof(uint32_t));
+  auto trainer_id = request.client_id();
+  if (table->push_dense(values, trainer_id) != 0) {
+    set_response_code(response, -1, "run_program failed");
+  }
+
+  return 0;
+}
+
 }  // namespace distributed
 }  // namespace paddle
```
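The one-line registration in _service_handler_map is what routes the new command id to the handler; the service loop itself is not part of this diff. A self-contained analogue of that member-function-pointer dispatch pattern (illustrative names, not the PsService implementation):

```cpp
#include <cstdint>
#include <unordered_map>

// Analogue of the handler-map dispatch used above: command ids map to member
// function pointers, so adding a service is one map entry plus one method.
struct MiniService {
  using Handler = int32_t (MiniService::*)();
  std::unordered_map<int32_t, Handler> handlers;

  int32_t push_global_step() { return 0; }

  int32_t dispatch(int32_t cmd_id) {
    auto it = handlers.find(cmd_id);
    if (it == handlers.end()) return -1;  // unknown command
    return (this->*(it->second))();
  }
};

// usage: MiniService s; s.handlers[1002] = &MiniService::push_global_step;
```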

paddle/fluid/distributed/service/brpc_ps_server.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -110,6 +110,9 @@ class PsService : public PsBaseService {
   int32_t print_table_stat(Table *table, const PsRequestMessage &request,
                            PsResponseMessage &response, brpc::Controller *cntl);

+  int32_t push_global_step(Table *table, const PsRequestMessage &request,
+                           PsResponseMessage &response, brpc::Controller *cntl);
+
   bool _is_initialize_shard_info;
   std::mutex _initialize_shard_mutex;
   std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
```

paddle/fluid/distributed/service/communicator.cc

Lines changed: 53 additions & 3 deletions
```diff
@@ -34,6 +34,9 @@ limitations under the License. */
 #include "paddle/fluid/string/printf.h"
 #include "paddle/fluid/string/split.h"

+#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
+#define STEP_COUNTER "@PS_STEP_COUNTER@"
+
 namespace paddle {
 namespace distributed {

@@ -377,6 +380,37 @@ void Communicator::RpcProfilerControl() {
   }
 }

+void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
+                                  Scope *send_scope) {
+  if (batches == 0) {
+    return;
+  }
+  auto &table_id = ctx.table_id;
+  size_t request_call_num = _worker_ptr->get_server_nums();
+
+  auto &var_name = STEP_COUNTER;
+  auto *out_var = send_scope->Var(var_name);
+  auto *out_t = out_var->GetMutable<framework::LoDTensor>();
+  auto *data = out_t->mutable_data<int64_t>({1}, platform::CPUPlace());
+  data[0] = static_cast<int64_t>(batches);
+  VLOG(3) << "Communicator::SendGlobalStep send: " << batches;
+  DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+      request_call_num, [this, request_call_num](void *done) {
+        int ret = 0;
+        auto *closure = (DownpourBrpcClosure *)done;
+        for (size_t i = 0; i < request_call_num; ++i) {
+          if (closure->check_response(i, PS_PUSH_GLOBAL_STEP) != 0) {
+            ret = -1;
+            break;
+          }
+        }
+        closure->set_promise_value(ret);
+      });
+  auto status = _worker_ptr->push_global_step(table_id, data, closure);
+  status.wait();
+  return;
+}
+
 void AsyncCommunicator::RecvThread() {
   if (!independent_recv_) return;
   VLOG(3) << "Independent RecvThread Start and Wait";
@@ -465,10 +499,16 @@ void AsyncCommunicator::SendByCommunicator() {

   for (size_t i = 0; i < var_nums; i++) {
     auto &var_name = varnames[i];
-    MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
+    if (var_name == STEP_COUNTER) {
+      MergeVars<int64_t>(var_name, vars[i], send_scope_.get(), 1);
+    } else {
+      MergeVars<float>(var_name, vars[i], send_scope_.get(), 1);
+    }
   }

-  if (ctx.is_sparse) {
+  if (ctx.is_tensor_table) {
+    SendGlobalStep(ctx, merged_var_num, send_scope_.get());
+  } else if (ctx.is_sparse) {
     PADDLE_ENFORCE_EQ(
         varnames.size(), 1,
         platform::errors::InvalidArgument(
@@ -599,8 +639,18 @@ bool AsyncCommunicator::Check(const std::vector<std::string> &var_tables) {
       platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));

   auto table_name = var_tables[0];
-  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end())
+  if (send_varname_to_ctx_.find(table_name) == send_varname_to_ctx_.end()) {
     return false;
+  }
+  if (table_name == STEP_COUNTER) {
+    VLOG(3) << "send step_counter into queue";
+    auto tmp_var = std::make_shared<Variable>();
+    auto *tensor = tmp_var->GetMutable<framework::LoDTensor>();
+    tensor->Resize(framework::make_ddim({1}));
+    auto *out_d = tensor->mutable_data<int64_t>(platform::CPUPlace());
+    out_d[0] = 1;
+    send_varname_to_queue_[table_name]->Push(tmp_var);
+  }
   return true;
 }
```
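Each Check() call enqueues a {1} tensor for STEP_COUNTER; SendByCommunicator merges the queued tensors with MergeVars<int64_t> and hands the batch count (merged_var_num) to SendGlobalStep, which pushes a single int64 to the server. A self-contained stand-in for that merge-add semantics (illustrative only, not the MergeVars implementation):

```cpp
#include <cstdint>
#include <vector>

// Illustrative merge-add over queued step-counter values.
int64_t MergeStepCounters(const std::vector<int64_t> &queued) {
  int64_t total = 0;
  for (int64_t v : queued) total += v;
  return total;
}
// e.g. three queued batches: MergeStepCounters({1, 1, 1}) == 3,
// which then travels to the server via push_global_step.
```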

paddle/fluid/distributed/service/communicator.h

Lines changed: 3 additions & 4 deletions
```diff
@@ -223,6 +223,9 @@ class Communicator {
   // 6. recv sparse param
   virtual void RpcRecvSparse(const std::string &varname, int table_id,
                              Scope *scope);
+  // 7. send global step
+  virtual void SendGlobalStep(const CommContext &ctx, int batches,
+                              Scope *send_scope);

   virtual ~Communicator() {}
   virtual void RpcProfilerControl();
@@ -376,8 +379,6 @@ class AsyncCommunicator : public Communicator {

   virtual void SendByCommunicator();

-  virtual void SendGlobalStep(int batches) {}
-
   virtual void RecvByCommunicator();

   virtual void RecvNoBarrier();
@@ -527,8 +528,6 @@ class GeoCommunicator : public AsyncCommunicator {

   void SendByCommunicator() { return; }

-  void SendGlobalStep(int batches) override { return; }
-
   void RecvByCommunicator() override { return; }

   inline std::string GradToParam(const std::string var_name) {
```
