Skip to content

Commit eb53bfa

Browse files
authored
Merge pull request #4 from WeiyueSu/batch_sample_k
random_sample return future
2 parents 2abf38c + 2a70bd8 commit eb53bfa

File tree

5 files changed

+35
-35
lines changed

5 files changed

+35
-35
lines changed

paddle/fluid/distributed/service/graph_brpc_server.cc

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -283,28 +283,22 @@ int32_t GraphBrpcService::graph_random_sample(Table *table,
283283
"graph_random_sample request requires at least 2 arguments");
284284
return 0;
285285
}
286-
size_t num_nodes = request.params(0).size() / sizeof(uint64_t);
286+
size_t node_num = request.params(0).size() / sizeof(uint64_t);
287287
uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
288288
int sample_size = *(uint64_t *)(request.params(1).c_str());
289289

290-
std::vector<std::future<int>*> tasks;
291-
std::vector<char*> buffers(num_nodes);
292-
std::vector<int> actual_sizes(num_nodes);
290+
std::vector<char*> buffers(node_num, nullptr);
291+
std::vector<int> actual_sizes(node_num, 0);
292+
table->random_sample(node_data, sample_size, buffers, actual_sizes);
293293

294-
for (size_t idx = 0; idx < num_nodes; ++idx){
295-
//std::future<int> task = table->random_sample(node_data[idx], sample_size,
296-
//buffers[idx], actual_sizes[idx]);
297-
table->random_sample(node_data[idx], sample_size,
298-
buffers[idx], actual_sizes[idx]);
299-
//tasks.push_back(&task);
300-
}
301-
//for (size_t idx = 0; idx < num_nodes; ++idx){
302-
//tasks[idx]->get();
303-
//}
304-
cntl->response_attachment().append(&num_nodes, sizeof(size_t));
305-
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*num_nodes);
306-
for (size_t idx = 0; idx < num_nodes; ++idx){
294+
cntl->response_attachment().append(&node_num, sizeof(size_t));
295+
cntl->response_attachment().append(actual_sizes.data(), sizeof(int)*node_num);
296+
for (size_t idx = 0; idx < node_num; ++idx){
307297
cntl->response_attachment().append(buffers[idx], actual_sizes[idx]);
298+
if (buffers[idx] != nullptr){
299+
delete buffers[idx];
300+
buffers[idx] = nullptr;
301+
}
308302
}
309303
return 0;
310304
}

paddle/fluid/distributed/table/common_graph_table.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,16 @@ GraphNode *GraphTable::find_node(uint64_t id) {
200200
uint32_t GraphTable::get_thread_pool_index(uint64_t node_id) {
201201
return node_id % shard_num_per_table % task_pool_size_;
202202
}
203-
//std::future<int> GraphTable::random_sample(uint64_t node_id, int sample_size,
204-
//char *&buffer, int &actual_size) {
205-
int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
206-
char *&buffer, int &actual_size) {
207-
return _shards_task_pool[get_thread_pool_index(node_id)]
203+
int GraphTable::random_sample(uint64_t* node_ids, int sample_size,
204+
std::vector<char*>& buffers, std::vector<int> &actual_sizes) {
205+
size_t node_num = buffers.size();
206+
std::vector<std::future<int>> tasks;
207+
for (size_t idx = 0; idx < node_num; ++idx){
208+
uint64_t node_id = node_ids[idx];
209+
char* & buffer = buffers[idx];
210+
int& actual_size = actual_sizes[idx];
211+
tasks.push_back(_shards_task_pool[get_thread_pool_index(node_id)]
208212
->enqueue([&]() -> int {
209-
210213
GraphNode *node = find_node(node_id);
211214
if (node == NULL) {
212215
actual_size = 0;
@@ -229,8 +232,12 @@ int32_t GraphTable::random_sample(uint64_t node_id, int sample_size,
229232
offset += GraphNode::weight_size;
230233
}
231234
return 0;
232-
})
233-
.get();
235+
}));
236+
}
237+
for (size_t idx = 0; idx < node_num; ++idx){
238+
tasks[idx].get();
239+
}
240+
return 0;
234241
}
235242
int32_t GraphTable::pull_graph_list(int start, int total_size, char *&buffer,
236243
int &actual_size) {

paddle/fluid/distributed/table/common_graph_table.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,8 @@ class GraphTable : public SparseTable {
7171
virtual ~GraphTable() {}
7272
virtual int32_t pull_graph_list(int start, int size, char *&buffer,
7373
int &actual_size);
74-
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
75-
//int &actual_size);
76-
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
77-
int &actual_size);
74+
virtual int random_sample(uint64_t* node_ids, int sampe_size, std::vector<char *>&buffers,
75+
std::vector<int> &actual_sizes);
7876
virtual int32_t initialize();
7977

8078
int32_t load(const std::string &path, const std::string &param);

paddle/fluid/distributed/table/table.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,10 @@ class Table {
9393
return 0;
9494
}
9595
// only for graph table
96-
virtual int32_t random_sample(uint64_t node_id, int sampe_size, char *&buffer,
97-
int &actual_size) {
96+
virtual int random_sample(uint64_t* node_ids, int sampe_size, std::vector<char *>&buffers,
97+
std::vector<int> &actual_sizes) {
9898
return 0;
9999
}
100-
//virtual std::future<int> random_sample(uint64_t node_id, int sampe_size, char *&buffer,
101-
//int &actual_size) {
102-
//return std::future<int>();
103-
//}
104100
virtual int32_t pour() { return 0; }
105101

106102
virtual void clear() = 0;

paddle/fluid/distributed/test/graph_node_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ void RunBrpcPushSparse() {
236236
ASSERT_EQ(true, s.find(g.first) != s.end());
237237
}
238238
vs.clear();
239+
240+
pull_status = worker_ptr_->batch_sample(0, std::vector<uint64_t>(1, 10240001024), 4, vs);
241+
pull_status.wait();
242+
ASSERT_EQ(0, vs[0].size());
243+
239244
std::vector<distributed::GraphNode> nodes;
240245
pull_status = worker_ptr_->pull_graph_list(0, 0, 0, 1, nodes);
241246
pull_status.wait();

0 commit comments

Comments
 (0)