Skip to content

Commit 1334f2c

Browse files
authored
Merge pull request #6 from seemingwang/develop
Merge SampleK
2 parents 2feadfe + eb53bfa commit 1334f2c

File tree

11 files changed

+153
-98
lines changed

11 files changed

+153
-98
lines changed

paddle/fluid/distributed/service/graph_brpc_client.cc

Lines changed: 77 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,43 +35,91 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
3535
return id % shard_num / shard_per_server;
3636
}
3737
// char* &buffer,int &actual_size
38-
std::future<int32_t> GraphBrpcClient::sample(
39-
uint32_t table_id, uint64_t node_id, int sample_size,
40-
std::vector<std::pair<uint64_t, float>> &res) {
41-
int server_index = get_server_index_by_id(node_id);
42-
DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
38+
std::future<int32_t> GraphBrpcClient::batch_sample(uint32_t table_id,
39+
std::vector<uint64_t> node_ids, int sample_size,
40+
std::vector<std::vector<std::pair<uint64_t, float> > > &res) {
41+
42+
std::vector<int> request2server;
43+
std::vector<int> server2request(server_size, -1);
44+
res.clear();
45+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
46+
int server_index = get_server_index_by_id(node_ids[query_idx]);
47+
if(server2request[server_index] == -1){
48+
server2request[server_index] = request2server.size();
49+
request2server.push_back(server_index);
50+
}
51+
//res.push_back(std::vector<GraphNode>());
52+
res.push_back(std::vector<std::pair<uint64_t, float>>());
53+
}
54+
size_t request_call_num = request2server.size();
55+
std::vector<std::vector<uint64_t> > node_id_buckets(request_call_num);
56+
std::vector<std::vector<int> > query_idx_buckets(request_call_num);
57+
for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx){
58+
int server_index = get_server_index_by_id(node_ids[query_idx]);
59+
int request_idx = server2request[server_index];
60+
node_id_buckets[request_idx].push_back(node_ids[query_idx]);
61+
query_idx_buckets[request_idx].push_back(query_idx);
62+
}
63+
64+
DownpourBrpcClosure *closure = new DownpourBrpcClosure(request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
4365
int ret = 0;
4466
auto *closure = (DownpourBrpcClosure *)done;
45-
if (closure->check_response(0, PS_GRAPH_SAMPLE) != 0) {
46-
ret = -1;
47-
} else {
48-
auto &res_io_buffer = closure->cntl(0)->response_attachment();
49-
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
50-
size_t bytes_size = io_buffer_itr.bytes_left();
51-
char *buffer = new char[bytes_size];
52-
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
53-
int offset = 0;
54-
while (offset < bytes_size) {
55-
res.push_back({*(uint64_t *)(buffer + offset),
56-
*(float *)(buffer + offset + GraphNode::id_size)});
57-
offset += GraphNode::id_size + GraphNode::weight_size;
67+
int fail_num = 0;
68+
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
69+
if (closure->check_response(request_idx, PS_GRAPH_SAMPLE) != 0) {
70+
++fail_num;
71+
} else {
72+
VLOG(0) << "check sample response: "
73+
<< " " << closure->check_response(request_idx, PS_GRAPH_SAMPLE);
74+
auto &res_io_buffer = closure->cntl(request_idx)->response_attachment();
75+
butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
76+
size_t bytes_size = io_buffer_itr.bytes_left();
77+
char *buffer = new char[bytes_size];
78+
io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
79+
80+
size_t node_num = *(size_t *)buffer;
81+
int *actual_sizes = (int *)(buffer + sizeof(size_t));
82+
char *node_buffer = buffer + sizeof(size_t) + sizeof(int) * node_num;
83+
84+
int offset = 0;
85+
for (size_t node_idx = 0; node_idx < node_num; ++node_idx){
86+
int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
87+
int actual_size = actual_sizes[node_idx];
88+
int start = 0;
89+
while (start < actual_size) {
90+
res[query_idx].push_back({*(uint64_t *)(node_buffer + offset + start),
91+
*(float *)(node_buffer + offset + start + GraphNode::id_size)});
92+
start += GraphNode::id_size + GraphNode::weight_size;
93+
}
94+
offset += actual_size;
95+
}
96+
}
97+
if (fail_num == request_call_num){
98+
ret = -1;
5899
}
59100
}
60101
closure->set_promise_value(ret);
61102
});
103+
62104
auto promise = std::make_shared<std::promise<int32_t>>();
63105
closure->add_promise(promise);
64106
std::future<int> fut = promise->get_future();
65-
;
66-
closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE);
67-
closure->request(0)->set_table_id(table_id);
68-
closure->request(0)->set_client_id(_client_id);
69-
closure->request(0)->add_params((char *)&node_id, sizeof(uint64_t));
70-
closure->request(0)->add_params((char *)&sample_size, sizeof(int));
71-
PsService_Stub rpc_stub(get_cmd_channel(server_index));
72-
closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
73-
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
74-
closure);
107+
108+
for (int request_idx = 0; request_idx < request_call_num; ++request_idx){
109+
int server_index = request2server[request_idx];
110+
closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE);
111+
closure->request(request_idx)->set_table_id(table_id);
112+
closure->request(request_idx)->set_client_id(_client_id);
113+
// std::string type_str = GraphNode::node_type_to_string(type);
114+
size_t node_num = node_id_buckets[request_idx].size();
115+
116+
closure->request(request_idx)->add_params((char *)node_id_buckets[request_idx].data(), sizeof(uint64_t)*node_num);
117+
closure->request(request_idx)->add_params((char *)&sample_size, sizeof(int));
118+
PsService_Stub rpc_stub(get_cmd_channel(server_index));
119+
closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
120+
rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx), closure->response(request_idx),
121+
closure);
122+
}
75123

76124
return fut;
77125
}
@@ -124,4 +172,4 @@ int32_t GraphBrpcClient::initialize() {
124172
return 0;
125173
}
126174
}
127-
}
175+
}

paddle/fluid/distributed/service/graph_brpc_client.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ class GraphBrpcClient : public BrpcPsClient {
3636
public:
3737
GraphBrpcClient() {}
3838
virtual ~GraphBrpcClient() {}
39-
virtual std::future<int32_t> sample(
40-
uint32_t table_id, uint64_t node_id, int sample_size,
41-
std::vector<std::pair<uint64_t, float>> &res);
39+
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
40+
int sample_size,
41+
std::vector<std::vector<std::pair<uint64_t, float>>> &res);
4242
virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
4343
int server_index, int start,
4444
int size,

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,23 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
283283
"graph_random_sample request requires at least 2 arguments");
284284
return 0;
285285
}
286-
uint64_t node_id = *(uint64_t *)(request.params(0).c_str());
286+
size_t node_num = request.params(0).size() / sizeof(uint64_t);
287+
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
287288
int sample_size = *(uint64_t *)(request.params(1).c_str());
288-
char *buffer;
289-
int actual_size;
290-
table->random_sample(node_id, sample_size, buffer, actual_size);
291-
cntl->response_attachment().append(buffer, actual_size);
289+
290+
std::vector<char*> buffers(node_num, nullptr);
291+
std::vector<int> actual_sizes(node_num, 0);
292+
table->random_sample(node_data, sample_size, buffers, actual_sizes);
293+
294+
cntl->response_attachment().append(&node_num, sizeof(size_t));
295+
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*node_num);
296+
for (size_t idx = 0; idx < node_num; ++idx){
297+
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
298+
if (buffers[idx] != nullptr){
299+
delete buffers[idx];
300+
buffers[idx] = nullptr;
301+
}
302+
}
292303
return 0;
293304
}
294305

paddle/fluid/distributed/service/graph_py_service.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ void GraphPyClient::load_node_file(std::string name, std::string filepath) {
183183
status.wait();
184184
}
185185
}
186-
std::vector<std::pair<uint64_t, float>> GraphPyClient::sample_k(
187-
std::string name, uint64_t node_id, int sample_size) {
188-
std::vector<std::pair<uint64_t, float>> v;
186+
std::vector<std::vector<std::pair<uint64_t, float> > > GraphPyClient::batch_sample_k(
187+
std::string name, std::vector<uint64_t> node_ids, int sample_size) {
188+
std::vector<std::vector<std::pair<uint64_t, float> > > v;
189189
if (this->table_id_map.count(name)) {
190190
uint32_t table_id = this->table_id_map[name];
191-
auto status = worker_ptr->sample(table_id, node_id, sample_size, v);
191+
auto status = worker_ptr->batch_sample(table_id, node_ids, sample_size, v);
192192
status.wait();
193193
}
194194
return v;

paddle/fluid/distributed/service/graph_py_service.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ class GraphPyClient : public GraphPyService {
119119
int get_client_id() { return client_id; }
120120
void set_client_id(int client_id) { this->client_id = client_id; }
121121
void start_client();
122-
std::vector<std::pair<uint64_t, float>> sample_k(std::string name,
123-
uint64_t node_id,
124-
int sample_size);
122+
std::vector<std::vector<std::pair<uint64_t, float> > > batch_sample_k(
123+
std::string name, std::vector<uint64_t> node_ids, int sample_size);
125124
std::vector<GraphNode> pull_graph_list(std::string name, int server_index,
126125
int start, int size);
127126
::paddle::distributed::PSParameter GetWorkerProto();

paddle/fluid/distributed/service/ps_client.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ class PSClient {
155155
promise.set_value(-1);
156156
return fut;
157157
}
158-
virtual std::future<int32_t> sample(
159-
uint32_t table_id, uint64_t node_id, int sample_size,
160-
std::vector<std::pair<uint64_t, float>> &res) {
158+
virtual std::future<int32_t> batch_sample(uint32_t table_id, std::vector<uint64_t> node_ids,
159+
int sample_size,
160+
std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
161161
LOG(FATAL) << "Did not implement";
162162
std::promise<int32_t> promise;
163163
std::future<int> fut = promise.get_future();

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,16 @@ GraphNode *GraphTable::find_node(uint64_t id) {
200200
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
201201
return node_id % shard_num_per_table % task_pool_size_;
202202
}
203-
int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
204-
char *&buffer, int &actual_size) {
205-
return _shards_task_pool[get_thread_pool_index(node_id)]
203+
int GraphTable::random_sample(uint64_t* node_ids, int sample_size,
204+
std::vector<char*>& buffers, std::vector<int> &actual_sizes) {
205+
size_t node_num = buffers.size();
206+
std::vector<std::future<int>> tasks;
207+
for (size_t idx = 0; idx < node_num; ++idx){
208+
uint64_t node_id = node_ids[idx];
209+
char* & buffer = buffers[idx];
210+
int& actual_size = actual_sizes[idx];
211+
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]
206212
->enqueue([&]() -> int {
207-
208213
GraphNode *node = find_node(node_id);
209214
if (node == NULL) {
210215
actual_size = 0;
@@ -226,8 +231,13 @@ int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
226231
memcpy(buffer + offset, &weight, GraphNode::weight_size);
227232
offset += GraphNode::weight_size;
228233
}
229-
})
230-
.get();
234+
return 0;
235+
}));
236+
}
237+
for (size_t idx = 0; idx < node_num; ++idx){
238+
tasks[idx].get();
239+
}
240+
return 0;
231241
}
232242
int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
233243
int &actual_size) {

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ class GraphTable : public SparseTable {
7171
virtual ~GraphTable() {}
7272
virtual int32_t pull_graph_list(int start, int size, char *&buffer,
7373
int &actual_size);
74-
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
75-
int &actual_size);
74+
virtual int random_sample(uint64_t* node_ids, int sampe_size, std::vector<char *>&buffers,
75+
std::vector<int> &actual_sizes);
7676
virtual int32_t initialize();
7777

7878
int32_t load(const std::string &path, const std::string &param);

paddle/fluid/distributed/table/table.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ class Table {
9393
return 0;
9494
}
9595
// only for graph table
96-
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
97-
int &actual_size) {
96+
virtual int random_sample(uint64_t* node_ids, int sampe_size, std::vector<char *>&buffers,
97+
std::vector<int> &actual_sizes) {
9898
return 0;
9999
}
100100
virtual int32_t pour() { return 0; }

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 26 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -216,23 +216,31 @@ void RunBrpcPushSparse() {
216216

217217
/*-----------------------Test Server Init----------------------------------*/
218218
auto pull_status =
219-
worker_ptr_->load(0, std::string(file_name), std::string(""));
219+
worker_ptr_->load(0, std::string(file_name), std::string("edge"));
220220

221221
pull_status.wait();
222-
std::vector<std::pair<uint64_t, float>> v;
223-
pull_status = worker_ptr_->sample(0, 37, 4, v);
222+
std::vector<std::vector<std::pair<uint64_t, float> > > vs;
223+
//std::vector<std::pair<uint64_t, float>> v;
224+
//pull_status = worker_ptr_->sample(0, 37, 4, v);
225+
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 37), 4, vs);
224226
pull_status.wait();
225-
ASSERT_EQ(v.size(), 3);
226-
v.clear();
227-
pull_status = worker_ptr_->sample(0, 96, 4, v);
227+
ASSERT_EQ(vs[0].size(), 3);
228+
vs.clear();
229+
//pull_status = worker_ptr_->sample(0, 96, 4, v);
230+
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 96), 4, vs);
228231
pull_status.wait();
229232
std::unordered_set<int> s = {111, 48, 247};
230-
ASSERT_EQ(3, v.size());
231-
for (auto g : v) {
233+
ASSERT_EQ(3, vs[0].size());
234+
for (auto g : vs[0]) {
232235
// std::cout << g.first << std::endl;
233236
ASSERT_EQ(true, s.find(g.first) != s.end());
234237
}
235-
v.clear();
238+
vs.clear();
239+
240+
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 10240001024), 4, vs);
241+
pull_status.wait();
242+
ASSERT_EQ(0, vs[0].size());
243+
236244
std::vector<distributed::GraphNode> nodes;
237245
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
238246
pull_status.wait();
@@ -276,38 +284,17 @@ void RunBrpcPushSparse() {
276284
nodes = client2.pull_graph_list(std::string("user2item"), 0, 1, 4);
277285
ASSERT_EQ(nodes[0].get_id(), 59);
278286
nodes.clear();
279-
v = client1.sample_k(std::string("user2item"), 96, 4);
280-
ASSERT_EQ(v.size(), 3);
281-
std::cout << "sample result" << std::endl;
282-
for (auto p : v) {
287+
vs = client1.batch_sample_k(std::string("user2item"), std::vector<uint64_t>(1, 96), 4);
288+
ASSERT_EQ(vs[0].size(), 3);
289+
std::cout << "batch sample result" << std::endl;
290+
for (auto p : vs[0]) {
283291
std::cout << p.first << " " << p.second << std::endl;
284292
}
285-
/*
286-
from paddle.fluid.core import GraphPyService
287-
ips_str = "127.0.0.1:4211;127.0.0.1:4212"
288-
server1 = GraphPyServer()
289-
server2 = GraphPyServer()
290-
client1 = GraphPyClient()
291-
client2 = GraphPyClient()
292-
edge_types = ["user2item"]
293-
server1.set_up(ips_str,127,edge_types,0);
294-
server2.set_up(ips_str,127,edge_types,1);
295-
client1.set_up(ips_str,127,edge_types,0);
296-
client2.set_up(ips_str,127,edge_types,1);
297-
server1.start_server();
298-
server2.start_server();
299-
client1.start_client();
300-
client2.start_client();
301-
client1.load_edge_file(user2item", "input.txt", 0);
302-
list = client2.pull_graph_list("user2item",0,1,4)
303-
for x in list:
304-
print(x.get_id())
305-
306-
list = client1.sample_k("user2item",96, 4);
307-
for x in list:
308-
print(x.get_id())
309-
*/
310-
293+
std::vector<uint64_t> node_ids;
294+
node_ids.push_back(96);
295+
node_ids.push_back(37);
296+
vs = client1.batch_sample_k(std::string("user2item"), node_ids, 4);
297+
ASSERT_EQ(vs.size(), 2);
311298
// to test in python,try this:
312299
// from paddle.fluid.core import GraphPyService
313300
// ips_str = "127.0.0.1:4211;127.0.0.1:4212"

0 commit comments

Comments
 (0)