Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 65 additions & 56 deletions paddle/fluid/distributed/service/graph_brpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,90 +35,99 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
return id % shard_num / shard_per_server;
}
// char* &buffer,int &actual_size
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {

std::future<int32_t> GraphBrpcClient::batch_sample(
uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
std::vector<int> request2server;
std::vector<int> server2request(server_size, -1);
res.clear();
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
if(server2request[server_index] == -1){
if (server2request[server_index] == -1) {
server2request[server_index] = request2server.size();
request2server.push_back(server_index);
}
//res.push_back(std::vector<GraphNode>());
// res.push_back(std::vector<GraphNode>());
res.push_back(std::vector<std::pair<uint64_t, float>>());
}
size_t request_call_num = request2server.size();
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
int server_index = get_server_index_by_id(node_ids[query_idx]);
int request_idx = server2request[server_index];
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
query_idx_buckets[request_idx].push_back(query_idx);
}

DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
int fail_num = 0;
for (int request_idx = 0; request_idx < request_call_num;
++request_idx) {
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
++fail_num;
} else {
auto &res_io_buffer =
closure->cntl(request_idx)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
// char buffer[bytes_size];
std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
char *buffer = buffer_wrapper.get();
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);

size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
size_t node_num = *(size_t *)buffer;
int *actual_sizes = (int *)(buffer + sizeof(size_t));
char *node_buffer =
buffer + sizeof(size_t) + sizeof(int) * node_num;

int offset = 0;
for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
int actual_size = actual_sizes[node_idx];
int start = 0;
while (start < actual_size) {
res[query_idx].push_back(
{*(uint64_t *)(node_buffer + offset + start),
*(float *)(node_buffer + offset + start +
GraphNode::id_size)});
start += GraphNode::id_size + GraphNode::weight_size;
}
offset += actual_size;
}
}
if (fail_num == request_call_num) {
ret = -1;
}
offset += actual_size;
}
}
if (fail_num == request_call_num){
ret = -1;
}
}
closure->set_promise_value(ret);
});
closure->set_promise_value(ret);
});

auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){

for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
int server_index = request2server[request_idx];
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
closure->request(request_idx)->set_table_id(table_id);
closure->request(request_idx)->set_client_id(_client_id);
// std::string type_str = GraphNode::node_type_to_string(type);
size_t node_num = node_id_buckets[request_idx].size();

closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));

closure->request(request_idx)
->add_params((char *)node_id_buckets[request_idx].data(),
sizeof(uint64_t) * node_num);
closure->request(request_idx)
->add_params((char *)&sample_size, sizeof(int));
PsService_Stub rpc_stub(get_cmd_channel(server_index));
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
closure);
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
closure->response(request_idx), closure);
}

return fut;
Expand All @@ -133,12 +142,12 @@ std::future<int32_t> GraphBrpcClient::pull_graph_list(
if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
ret = -1;
} else {
VLOG(0) << "check sample response: "
<< " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
// VLOG(0) << "check sample response: "
// << " " << closure->check_response(0, PS_PULL_GRAPH_LIST);
auto &res_io_buffer = closure->cntl(0)->response_attachment();
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
size_t bytes_size = io_buffer_itr.bytes_left();
char *buffer = new char[bytes_size];
char buffer[bytes_size];
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
int index = 0;
while (index < bytes_size) {
Expand Down
28 changes: 17 additions & 11 deletions paddle/fluid/distributed/service/graph_brpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,10 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
}
int start = *(int *)(request.params(0).c_str());
int size = *(int *)(request.params(1).c_str());
std::vector<float> res_data;
char *buffer;
std::unique_ptr<char[]> buffer;
int actual_size;
table->pull_graph_list(start, size, buffer, actual_size);
cntl->response_attachment().append(buffer, actual_size);
cntl->response_attachment().append(buffer.get(), actual_size);
return 0;
}
int32_t GraphBrpcService::graph_random_sample(Table *table,
Expand All @@ -287,19 +286,26 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
int sample_size = *(uint64_t *)(request.params(1).c_str());

std::vector<char*> buffers(node_num, nullptr);
std::vector<std::unique_ptr<char[]>> buffers(node_num);
std::vector<int> actual_sizes(node_num, 0);
table->random_sample(node_data, sample_size, buffers, actual_sizes);

cntl->response_attachment().append(&node_num, sizeof(size_t));
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*node_num);
for (size_t idx = 0; idx < node_num; ++idx){
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
if (buffers[idx] != nullptr){
delete buffers[idx];
buffers[idx] = nullptr;
}
cntl->response_attachment().append(actual_sizes.data(),
sizeof(int) * node_num);
for (size_t idx = 0; idx < node_num; ++idx) {
cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
// if (buffers[idx] != nullptr){
// delete buffers[idx];
// buffers[idx] = nullptr;
// }
}
// =======
// std::unique_ptr<char[]> buffer;
// int actual_size;
// table->random_sample(node_id, sample_size, buffer, actual_size);
// cntl->response_attachment().append(buffer.get(), actual_size);
// >>>>>>> Stashed changes
return 0;
}

Expand Down
96 changes: 47 additions & 49 deletions paddle/fluid/distributed/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ size_t GraphShard::get_size() {
return res;
}

std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id, std::string feature) {
std::list<GraphNode *>::iterator GraphShard::add_node(uint64_t id,
std::string feature) {
if (node_location.find(id) != node_location.end())
return node_location.find(id)->second;

Expand All @@ -89,14 +90,13 @@ GraphNode *GraphShard::find_node(uint64_t id) {

// Loads graph data from `path`, choosing the loader from `param`.
//
// path:  forwarded unchanged to load_edges() or load_nodes().
// param: '|'-separated list of option tokens. If it contains "edge" the
//        edges are loaded; otherwise nodes are loaded. The token "reverse"
//        is forwarded to load_edges() as its reverse_edge flag and has no
//        effect when nodes are loaded.
//
// Returns whatever the selected loader returns.
int32_t GraphTable::load(const std::string &path, const std::string &param) {
  auto cmd = paddle::string::split_string<std::string>(param, "|");
  // A set makes the option lookups independent of token order and duplicates.
  const std::set<std::string> cmd_set(cmd.begin(), cmd.end());
  const bool reverse_edge = cmd_set.count(std::string("reverse")) > 0;
  const bool load_edge = cmd_set.count(std::string("edge")) > 0;
  if (load_edge) {
    return this->load_edges(path, reverse_edge);
  }
  return this->load_nodes(path);
}

Expand All @@ -110,33 +110,28 @@ int32_t GraphTable::load_nodes(const std::string &path) {
if (values.size() < 2) continue;
auto id = std::stoull(values[1]);


size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(0) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;

}

std::string node_type = values[0];
std::vector<std::string > feature;
std::vector<std::string> feature;
feature.push_back(node_type);
for(size_t slice = 2; slice < values.size(); slice ++) {
for (size_t slice = 2; slice < values.size(); slice++) {
feature.push_back(values[slice]);
}
auto feat = paddle::string::join_strings(feature, '\t');
size_t index = shard_id - shard_start;
shards[index].add_node(id, feat);

}
}
return 0;
}


int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {

auto paths = paddle::string::split_string<std::string>(path, ";");
int count = 0;

Expand Down Expand Up @@ -173,7 +168,7 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge) {
VLOG(0) << "Load Finished Total Edge Count " << count;

// Build Sampler j

for (auto &shard : shards) {
auto bucket = shard.get_bucket();
for (int i = 0; i < bucket.size(); i++) {
Expand All @@ -200,46 +195,49 @@ GraphNode *GraphTable::find_node(uint64_t id) {
// Maps a node id to the index of the task pool that handles it: the id is
// first reduced modulo shard_num_per_table, then modulo task_pool_size_.
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
  const auto slot_in_table = node_id % shard_num_per_table;
  return slot_in_table % task_pool_size_;
}
int GraphTable::random_sample(uint64_t* node_ids, int sample_size,
std::vector<char*>& buffers, std::vector<int> &actual_sizes) {
int GraphTable::random_sample(uint64_t *node_ids, int sample_size,
std::vector<std::unique_ptr<char[]>> &buffers,
std::vector<int> &actual_sizes) {
size_t node_num = buffers.size();
std::vector<std::future<int>> tasks;
for (size_t idx = 0; idx < node_num; ++idx){
for (size_t idx = 0; idx < node_num; ++idx) {
uint64_t node_id = node_ids[idx];
char* & buffer = buffers[idx];
int& actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]
->enqueue([&]() -> int {
GraphNode *node = find_node(node_id);
if (node == NULL) {
actual_size = 0;
std::unique_ptr<char[]> &buffer = buffers[idx];
int &actual_size = actual_sizes[idx];
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]->enqueue(
[&]() -> int {
GraphNode *node = find_node(node_id);
if (node == NULL) {
actual_size = 0;
return 0;
}
std::vector<GraphEdge *> res = node->sample_k(sample_size);
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
int offset = 0;
uint64_t id;
float weight;
char *buffer_addr = new char[actual_size];
buffer.reset(buffer_addr);
for (auto &x : res) {
id = x->get_id();
weight = x->get_weight();
memcpy(buffer_addr + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer_addr + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
return 0;
}
return 0;
}
std::vector<GraphEdge *> res = node->sample_k(sample_size);
std::vector<GraphNode> node_list;
actual_size =
res.size() * (GraphNode::id_size + GraphNode::weight_size);
buffer = new char[actual_size];
int offset = 0;
uint64_t id;
float weight;
for (auto &x : res) {
id = x->get_id();
weight = x->get_weight();
memcpy(buffer + offset, &id, GraphNode::id_size);
offset += GraphNode::id_size;
memcpy(buffer + offset, &weight, GraphNode::weight_size);
offset += GraphNode::weight_size;
}
return 0;
}));
}));
}
for (size_t idx = 0; idx < node_num; ++idx){
for (size_t idx = 0; idx < node_num; ++idx) {
tasks[idx].get();
}
return 0;
}
int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
int32_t GraphTable::pull_graph_list(int start, int total_size,
std::unique_ptr<char[]> &buffer,
int &actual_size) {
if (start < 0) start = 0;
int size = 0, cur_size;
Expand Down Expand Up @@ -283,11 +281,12 @@ int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
size += res.back()[j]->get_size();
}
}
buffer = new char[size];
char *buffer_addr = new char[size];
buffer.reset(buffer_addr);
int index = 0;
for (size_t i = 0; i < res.size(); i++) {
for (size_t j = 0; j < res[i].size(); j++) {
res[i][j]->to_buffer(buffer + index);
res[i][j]->to_buffer(buffer_addr + index);
index += res[i][j]->get_size();
}
}
Expand Down Expand Up @@ -321,4 +320,3 @@ int32_t GraphTable::initialize() {
}
}
};

Loading