Skip to content

Commit 2457680

Browse files
authored
Merge pull request #12 from seemingwang/gpu_graph_engine3
changing int64 key to uint64 for graph engine
2 parents b01e178 + eae32ae commit 2457680

20 files changed

+430
-879
lines changed

paddle/fluid/distributed/ps/service/graph_brpc_server.cc

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,8 @@ int32_t GraphBrpcService::add_graph_node(Table *table,
143143

144144
int idx_ = *(int *)(request.params(0).c_str());
145145
size_t node_num = request.params(1).size() / sizeof(int64_t);
146-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
147-
// size_t node_num = request.params(0).size() / sizeof(int64_t);
148-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
149-
std::vector<int64_t> node_ids(node_data, node_data + node_num);
146+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
147+
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
150148
std::vector<bool> is_weighted_list;
151149
if (request.params_size() == 3) {
152150
size_t weight_list_size = request.params(2).size() / sizeof(bool);
@@ -177,11 +175,9 @@ int32_t GraphBrpcService::remove_graph_node(Table *table,
177175
return 0;
178176
}
179177
int idx_ = *(int *)(request.params(0).c_str());
180-
size_t node_num = request.params(1).size() / sizeof(int64_t);
181-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
182-
// size_t node_num = request.params(0).size() / sizeof(int64_t);
183-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
184-
std::vector<int64_t> node_ids(node_data, node_data + node_num);
178+
size_t node_num = request.params(1).size() / sizeof(uint64_t);
179+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
180+
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
185181

186182
((GraphTable *)table)->remove_graph_node(idx_, node_ids);
187183
return 0;
@@ -215,11 +211,6 @@ int32_t GraphBrpcService::Initialize() {
215211
&GraphBrpcService::graph_set_node_feat;
216212
_service_handler_map[PS_GRAPH_SAMPLE_NODES_FROM_ONE_SERVER] =
217213
&GraphBrpcService::sample_neighbors_across_multi_servers;
218-
// _service_handler_map[PS_GRAPH_USE_NEIGHBORS_SAMPLE_CACHE] =
219-
// &GraphBrpcService::use_neighbors_sample_cache;
220-
// _service_handler_map[PS_GRAPH_LOAD_GRAPH_SPLIT_CONFIG] =
221-
// &GraphBrpcService::load_graph_split_config;
222-
// shard初始化,server启动后才可从env获取到server_list的shard信息
223214
InitializeShardInfo();
224215

225216
return 0;
@@ -384,9 +375,6 @@ int32_t GraphBrpcService::pull_graph_list(Table *table,
384375
int start = *(int *)(request.params(2).c_str());
385376
int size = *(int *)(request.params(3).c_str());
386377
int step = *(int *)(request.params(4).c_str());
387-
// int start = *(int *)(request.params(0).c_str());
388-
// int size = *(int *)(request.params(1).c_str());
389-
// int step = *(int *)(request.params(2).c_str());
390378
std::unique_ptr<char[]> buffer;
391379
int actual_size;
392380
((GraphTable *)table)
@@ -406,14 +394,10 @@ int32_t GraphBrpcService::graph_random_sample_neighbors(
406394
return 0;
407395
}
408396
int idx_ = *(int *)(request.params(0).c_str());
409-
size_t node_num = request.params(1).size() / sizeof(int64_t);
410-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
411-
int sample_size = *(int64_t *)(request.params(2).c_str());
397+
size_t node_num = request.params(1).size() / sizeof(uint64_t);
398+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
399+
int sample_size = *(int *)(request.params(2).c_str());
412400
bool need_weight = *(bool *)(request.params(3).c_str());
413-
// size_t node_num = request.params(0).size() / sizeof(int64_t);
414-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
415-
// int sample_size = *(int64_t *)(request.params(1).c_str());
416-
// bool need_weight = *(bool *)(request.params(2).c_str());
417401
std::vector<std::shared_ptr<char>> buffers(node_num);
418402
std::vector<int> actual_sizes(node_num, 0);
419403
((GraphTable *)table)
@@ -433,7 +417,7 @@ int32_t GraphBrpcService::graph_random_sample_nodes(
433417
brpc::Controller *cntl) {
434418
int type_id = *(int *)(request.params(0).c_str());
435419
int idx_ = *(int *)(request.params(1).c_str());
436-
size_t size = *(int64_t *)(request.params(2).c_str());
420+
size_t size = *(uint64_t *)(request.params(2).c_str());
437421
// size_t size = *(int64_t *)(request.params(0).c_str());
438422
std::unique_ptr<char[]> buffer;
439423
int actual_size;
@@ -459,11 +443,9 @@ int32_t GraphBrpcService::graph_get_node_feat(Table *table,
459443
return 0;
460444
}
461445
int idx_ = *(int *)(request.params(0).c_str());
462-
size_t node_num = request.params(1).size() / sizeof(int64_t);
463-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
464-
// size_t node_num = request.params(0).size() / sizeof(int64_t);
465-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
466-
std::vector<int64_t> node_ids(node_data, node_data + node_num);
446+
size_t node_num = request.params(1).size() / sizeof(uint64_t);
447+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
448+
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
467449

468450
std::vector<std::string> feature_names =
469451
paddle::string::split_string<std::string>(request.params(2), "\t");
@@ -497,22 +479,15 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
497479
}
498480

499481
int idx_ = *(int *)(request.params(0).c_str());
500-
size_t node_num = request.params(1).size() / sizeof(int64_t),
482+
size_t node_num = request.params(1).size() / sizeof(uint64_t),
501483
size_of_size_t = sizeof(size_t);
502-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
503-
int sample_size = *(int64_t *)(request.params(2).c_str());
504-
bool need_weight = *(int64_t *)(request.params(3).c_str());
505-
506-
// size_t node_num = request.params(0).size() / sizeof(int64_t),
507-
// size_of_size_t = sizeof(size_t);
508-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
509-
// int sample_size = *(int64_t *)(request.params(1).c_str());
510-
// bool need_weight = *(int64_t *)(request.params(2).c_str());
511-
// std::vector<int64_t> res = ((GraphTable
512-
// *)table).filter_out_non_exist_nodes(node_data, sample_size);
484+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
485+
int sample_size = *(int *)(request.params(2).c_str());
486+
bool need_weight = *(bool *)(request.params(3).c_str());
487+
513488
std::vector<int> request2server;
514489
std::vector<int> server2request(server_size, -1);
515-
std::vector<int64_t> local_id;
490+
std::vector<uint64_t> local_id;
516491
std::vector<int> local_query_idx;
517492
size_t rank = GetRank();
518493
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
@@ -535,7 +510,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
535510
std::vector<std::shared_ptr<char>> local_buffers;
536511
std::vector<int> local_actual_sizes;
537512
std::vector<size_t> seq;
538-
std::vector<std::vector<int64_t>> node_id_buckets(request_call_num);
513+
std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
539514
std::vector<std::vector<int>> query_idx_buckets(request_call_num);
540515
for (int query_idx = 0; query_idx < node_num; ++query_idx) {
541516
int server_index =
@@ -624,7 +599,7 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers(
624599

625600
closure->request(request_idx)
626601
->add_params((char *)node_id_buckets[request_idx].data(),
627-
sizeof(int64_t) * node_num);
602+
sizeof(uint64_t) * node_num);
628603
closure->request(request_idx)
629604
->add_params((char *)&sample_size, sizeof(int));
630605
closure->request(request_idx)
@@ -661,11 +636,9 @@ int32_t GraphBrpcService::graph_set_node_feat(Table *table,
661636
}
662637
int idx_ = *(int *)(request.params(0).c_str());
663638

664-
// size_t node_num = request.params(0).size() / sizeof(int64_t);
665-
// int64_t *node_data = (int64_t *)(request.params(0).c_str());
666-
size_t node_num = request.params(1).size() / sizeof(int64_t);
667-
int64_t *node_data = (int64_t *)(request.params(1).c_str());
668-
std::vector<int64_t> node_ids(node_data, node_data + node_num);
639+
size_t node_num = request.params(1).size() / sizeof(uint64_t);
640+
uint64_t *node_data = (uint64_t *)(request.params(1).c_str());
641+
std::vector<uint64_t> node_ids(node_data, node_data + node_num);
669642

670643
// std::vector<std::string> feature_names =
671644
// paddle::string::split_string<std::string>(request.params(1), "\t");

paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,14 @@ class GraphPyService {
8181

8282
graph_proto->set_table_name("cpu_graph_table");
8383
graph_proto->set_use_cache(false);
84-
for (int i = 0; i < id_to_edge.size(); i++)
84+
for (int i = 0; i < (int)id_to_edge.size(); i++)
8585
graph_proto->add_edge_types(id_to_edge[i]);
86-
for (int i = 0; i < id_to_feature.size(); i++) {
86+
for (int i = 0; i < (int)id_to_feature.size(); i++) {
8787
graph_proto->add_node_types(id_to_feature[i]);
8888
auto feat_node = id_to_feature[i];
8989
::paddle::distributed::GraphFeature* g_f =
9090
graph_proto->add_graph_feature();
91-
for (int x = 0; x < table_feat_conf_feat_name[i].size(); x++) {
91+
for (int x = 0; x < (int)table_feat_conf_feat_name[i].size(); x++) {
9292
g_f->add_name(table_feat_conf_feat_name[i][x]);
9393
g_f->add_dtype(table_feat_conf_feat_dtype[i][x]);
9494
g_f->add_shape(table_feat_conf_feat_shape[i][x]);

0 commit comments

Comments
 (0)