Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 48 additions & 5 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_per_server;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample(
std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
std::vector<int> request2server;
Expand Down Expand Up @@ -68,7 +68,8 @@ std::future<int32_t> GraphBrpcClient::batch_sample(
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
if (closure->check_response(request_idx,
PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
Expand Down Expand Up @@ -113,7 +114,7 @@ std::future<int32_t> GraphBrpcClient::batch_sample(

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
// std::string type_str = GraphNode::node_type_to_string(type);
Expand All @@ -132,7 +133,49 @@ std::future<int32_t> GraphBrpcClient::batch_sample(

return fut;
}

// Asynchronously asks the server at `server_index` to randomly sample up to
// `sample_size` node ids from table `table_id`.
//
// The response attachment is parsed as a packed sequence of uint64_t ids laid
// out every GraphNode::id_size bytes; each id is appended to `ids`.
//
// Returns a future that resolves to 0 when check_response() accepts the reply
// and to -1 otherwise.
//
// NOTE: `ids` is written from the RPC completion callback, so the caller must
// keep it alive and leave it untouched until the returned future is ready.
std::future<int32_t> GraphBrpcClient::random_sample_nodes(
    uint32_t table_id, int server_index, int sample_size,
    std::vector<uint64_t> &ids) {
  // Capture only `ids`: the callback runs after this function has returned, so
  // a blanket [&] would make it easy to reference a dead local by accident.
  // NOTE(review): the closure is presumably released by the RPC framework once
  // the call completes — confirm against DownpourBrpcClosure.
  DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&ids](void *done) {
    int ret = 0;
    auto *rpc_closure = static_cast<DownpourBrpcClosure *>(done);
    if (rpc_closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
      ret = -1;
    } else {
      auto &res_io_buffer = rpc_closure->cntl(0)->response_attachment();
      butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
      const size_t bytes_size = io_buffer_itr.bytes_left();
      // The payload size is decided by the server, so keep it on the heap
      // instead of in a variable-length stack array (non-standard C++ and a
      // stack-overflow hazard for large responses).
      std::vector<char> buffer(bytes_size);
      io_buffer_itr.copy_and_forward(buffer.data(), bytes_size);
      ids.reserve(ids.size() + bytes_size / sizeof(uint64_t));
      // Stop as soon as fewer than sizeof(uint64_t) bytes remain so that a
      // truncated payload can never be read past the end of the buffer.
      for (size_t offset = 0; offset + sizeof(uint64_t) <= bytes_size;
           offset += GraphNode::id_size) {
        ids.push_back(
            *reinterpret_cast<const uint64_t *>(buffer.data() + offset));
      }
    }
    rpc_closure->set_promise_value(ret);
  });
  auto promise = std::make_shared<std::promise<int32_t>>();
  closure->add_promise(promise);
  std::future<int32_t> fut = promise->get_future();

  closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
  closure->request(0)->set_table_id(table_id);
  closure->request(0)->set_client_id(_client_id);
  // The sample size travels as the raw bytes of an `int` (sizeof(int) bytes).
  closure->request(0)->add_params(reinterpret_cast<const char *>(&sample_size),
                                  sizeof(int));

  PsService_Stub rpc_stub(get_cmd_channel(server_index));
  closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
  rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
                   closure);
  return fut;
}
std::future<int32_t> GraphBrpcClient::pull_graph_list(
uint32_t table_id, int server_index, int start, int size,
std::vector<GraphNode> &res) {
Expand All @@ -153,7 +196,7 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
while (index < bytes_size) {
GraphNode node;
node.recover_from_buffer(buffer + index);
index += node.get_size();
index += node.get_size(true);
res.push_back(node);
}
}
Expand Down
10 changes: 7 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,17 @@ class GraphBrpcClient : public BrpcPsClient {
public:
GraphBrpcClient() {}
virtual ~GraphBrpcClient() {}
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
virtual std::future<int32_t> batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
int server_index, int start,
int size,
std::vector<GraphNode> &res);
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
int server_index,
int sample_size,
std::vector<uint64_t> &ids);
virtual int32_t initialize();
int get_shard_num() { return shard_num; }
void set_shard_num(int shard_num) { this->shard_num = shard_num; }
Expand Down
41 changes: 22 additions & 19 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ int32_t GraphBrpcService::initialize() {
_service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;

_service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
_service_handler_map[PS_GRAPH_SAMPLE] =
&GraphBrpcService::graph_random_sample;
_service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
&GraphBrpcService::graph_random_sample_neighboors;
_service_handler_map[PS_GRAPH_SAMPLE_NODES] =
&GraphBrpcService::graph_random_sample_nodes;

// shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info();
Expand Down Expand Up @@ -267,14 +269,13 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
int size = *(int *)(request.params(1).c_str());
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size);
table->pull_graph_list(start, size, buffer, actual_size, true);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
int32_t GraphBrpcService::graph_random_sample(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t GraphBrpcService::graph_random_sample_neighboors(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
Expand All @@ -285,29 +286,31 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
size_t node_num = request.params(0).size() / sizeof(uint64_t);
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());

std::vector<std::unique_ptr<char[]>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
table->random_sample(node_data, sample_size, buffers, actual_sizes);
table->random_sample_neighboors(node_data, sample_size, buffers,
actual_sizes);

cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(),
sizeof(int) * node_num);
for (size_t idx = 0; idx < node_num; ++idx) {
cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
// if (buffers[idx] != nullptr){
// delete buffers[idx];
// buffers[idx] = nullptr;
// }
}
// =======
// std::unique_ptr<char[]> buffer;
// int actual_size;
// table->random_sample(node_id, sample_size, buffer, actual_size);
// cntl->response_attachment().append(buffer.get(), actual_size);
// >>>>>>> Stashed changes
return 0;
}
// Handles PS_GRAPH_SAMPLE_NODES: samples nodes from `table` and returns the
// serialized result in the response attachment.
//
// request.params(0) must hold the sample size as the raw bytes of an `int`.
// When the table reports failure the attachment is left empty.
//
// Returns 0 after handling the request (malformed requests are reported via
// the response code); a missing table is handled by CHECK_TABLE_EXIST.
int32_t GraphBrpcService::graph_random_sample_nodes(
    Table *table, const PsRequestMessage &request, PsResponseMessage &response,
    brpc::Controller *cntl) {
  CHECK_TABLE_EXIST(table, request, response)
  // GraphBrpcClient::random_sample_nodes sends exactly sizeof(int) bytes, so
  // the value must be decoded as an int; decoding it as a uint64_t would read
  // past the end of the parameter.
  if (request.params_size() < 1 || request.params(0).size() < sizeof(int)) {
    set_response_code(
        response, -1,
        "graph_random_sample_nodes request requires an int sample size");
    return 0;
  }
  const int sample_size =
      *reinterpret_cast<const int *>(request.params(0).c_str());
  if (sample_size < 0) {
    set_response_code(
        response, -1,
        "graph_random_sample_nodes request requires a non-negative sample "
        "size");
    return 0;
  }

  std::unique_ptr<char[]> buffer;
  int actual_size = 0;
  if (table->random_sample_nodes(sample_size, buffer, actual_size) == 0) {
    cntl->response_attachment().append(buffer.get(), actual_size);
  }
  // On failure nothing is appended: the client then parses an empty
  // attachment and receives no ids.
  return 0;
}
} // namespace distributed
} // namespace paddle
12 changes: 9 additions & 3 deletions paddle/fluid/distributed/service/graph_brpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,14 @@ class GraphBrpcService : public PsBaseService {
int32_t initialize_shard_info();
int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t graph_random_sample(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_random_sample_neighboors(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t graph_random_sample_nodes(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
int32_t barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t load_one_table(Table *table, const PsRequestMessage &request,
Expand All @@ -103,6 +108,7 @@ class GraphBrpcService : public PsBaseService {
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _msg_handler_map;
std::vector<float> _ori_values;
const int sample_nodes_ranges = 3;
};
// class GraphBrpcService : public BrpcPsService {
// public:
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/distributed/service/graph_py_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,14 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
status.wait();
}
}
std::vector<std::vector<std::pair<uint64_t, float> > > GraphPyClient::batch_sample_k(
std::string name, std::vector<uint64_t> node_ids, int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float> > > v;
std::vector<std::vector<std::pair<uint64_t, float>>>
GraphPyClient::batch_sample_k(std::string name, std::vector<uint64_t> node_ids,
int sample_size) {
std::vector<std::vector<std::pair<uint64_t, float>>> v;
if (this->table_id_map.count(name)) {
uint32_t table_id = this->table_id_map[name];
auto status = worker_ptr->batch_sample(table_id, node_ids, sample_size, v);
auto status =
worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
status.wait();
}
return v;
Expand Down
17 changes: 14 additions & 3 deletions paddle/fluid/distributed/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ class PSClient {
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
virtual std::future<int32_t> batch_sample_neighboors(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
LOG(FATAL) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
Expand All @@ -174,6 +174,17 @@ class PSClient {
promise.set_value(-1);
return fut;
}

// Default stub for clients that do not support random node sampling.
// Logs a fatal "Did not implement" message; the future handed back is already
// fulfilled with -1.
virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
                                                 int server_index,
                                                 int sample_size,
                                                 std::vector<uint64_t> &ids) {
  LOG(FATAL) << "Did not implement";
  std::promise<int32_t> unimplemented;
  unimplemented.set_value(-1);
  return unimplemented.get_future();
}
// Client-to-client message handling:
// std::function<int32_t(int, int, const std::string &)>
// -> ret (msg_type, from_client_id, msg)
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/service/sendrecv.proto
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ enum PsCmdID {
PS_STOP_PROFILER = 28;
PS_PUSH_GLOBAL_STEP = 29;
PS_PULL_GRAPH_LIST = 30;
PS_GRAPH_SAMPLE = 31;
PS_GRAPH_SAMPLE_NEIGHBOORS = 31;
PS_GRAPH_SAMPLE_NODES = 32;
}

message PsRequestMessage {
Expand Down
Loading