@@ -35,43 +35,91 @@ int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
3535 return id % shard_num / shard_per_server;
3636}
3737// char* &buffer,int &actual_size
38- std::future<int32_t > GraphBrpcClient::sample (
39- uint32_t table_id, uint64_t node_id, int sample_size,
40- std::vector<std::pair<uint64_t , float >> &res) {
41- int server_index = get_server_index_by_id (node_id);
42- DownpourBrpcClosure *closure = new DownpourBrpcClosure (1 , [&](void *done) {
38+ std::future<int32_t > GraphBrpcClient::batch_sample (uint32_t table_id,
39+ std::vector<uint64_t > node_ids, int sample_size,
40+ std::vector<std::vector<std::pair<uint64_t , float > > > &res) {
41+
42+ std::vector<int > request2server;
43+ std::vector<int > server2request (server_size, -1 );
44+ res.clear ();
45+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx){
46+ int server_index = get_server_index_by_id (node_ids[query_idx]);
47+ if (server2request[server_index] == -1 ){
48+ server2request[server_index] = request2server.size ();
49+ request2server.push_back (server_index);
50+ }
51+ // res.push_back(std::vector<GraphNode>());
52+ res.push_back (std::vector<std::pair<uint64_t , float >>());
53+ }
54+ size_t request_call_num = request2server.size ();
55+ std::vector<std::vector<uint64_t > > node_id_buckets (request_call_num);
56+ std::vector<std::vector<int > > query_idx_buckets (request_call_num);
57+ for (int query_idx = 0 ; query_idx < node_ids.size (); ++query_idx){
58+ int server_index = get_server_index_by_id (node_ids[query_idx]);
59+ int request_idx = server2request[server_index];
60+ node_id_buckets[request_idx].push_back (node_ids[query_idx]);
61+ query_idx_buckets[request_idx].push_back (query_idx);
62+ }
63+
64+ DownpourBrpcClosure *closure = new DownpourBrpcClosure (request_call_num, [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
4365 int ret = 0 ;
4466 auto *closure = (DownpourBrpcClosure *)done;
45- if (closure->check_response (0 , PS_GRAPH_SAMPLE) != 0 ) {
46- ret = -1 ;
47- } else {
48- auto &res_io_buffer = closure->cntl (0 )->response_attachment ();
49- butil::IOBufBytesIterator io_buffer_itr (res_io_buffer);
50- size_t bytes_size = io_buffer_itr.bytes_left ();
51- char *buffer = new char [bytes_size];
52- io_buffer_itr.copy_and_forward ((void *)(buffer), bytes_size);
53- int offset = 0 ;
54- while (offset < bytes_size) {
55- res.push_back ({*(uint64_t *)(buffer + offset),
56- *(float *)(buffer + offset + GraphNode::id_size)});
57- offset += GraphNode::id_size + GraphNode::weight_size;
67+ int fail_num = 0 ;
68+ for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx){
69+ if (closure->check_response (request_idx, PS_GRAPH_SAMPLE) != 0 ) {
70+ ++fail_num;
71+ } else {
72+ VLOG (0 ) << " check sample response: "
73+ << " " << closure->check_response (request_idx, PS_GRAPH_SAMPLE);
74+ auto &res_io_buffer = closure->cntl (request_idx)->response_attachment ();
75+ butil::IOBufBytesIterator io_buffer_itr (res_io_buffer);
76+ size_t bytes_size = io_buffer_itr.bytes_left ();
77+ char *buffer = new char [bytes_size];
78+ io_buffer_itr.copy_and_forward ((void *)(buffer), bytes_size);
79+
80+ size_t node_num = *(size_t *)buffer;
81+ int *actual_sizes = (int *)(buffer + sizeof (size_t ));
82+ char *node_buffer = buffer + sizeof (size_t ) + sizeof (int ) * node_num;
83+
84+ int offset = 0 ;
85+ for (size_t node_idx = 0 ; node_idx < node_num; ++node_idx){
86+ int query_idx = query_idx_buckets.at (request_idx).at (node_idx);
87+ int actual_size = actual_sizes[node_idx];
88+ int start = 0 ;
89+ while (start < actual_size) {
90+ res[query_idx].push_back ({*(uint64_t *)(node_buffer + offset + start),
91+ *(float *)(node_buffer + offset + start + GraphNode::id_size)});
92+ start += GraphNode::id_size + GraphNode::weight_size;
93+ }
94+ offset += actual_size;
95+ }
96+ }
97+ if (fail_num == request_call_num){
98+ ret = -1 ;
5899 }
59100 }
60101 closure->set_promise_value (ret);
61102 });
103+
62104 auto promise = std::make_shared<std::promise<int32_t >>();
63105 closure->add_promise (promise);
64106 std::future<int > fut = promise->get_future ();
65- ;
66- closure->request (0 )->set_cmd_id (PS_GRAPH_SAMPLE);
67- closure->request (0 )->set_table_id (table_id);
68- closure->request (0 )->set_client_id (_client_id);
69- closure->request (0 )->add_params ((char *)&node_id, sizeof (uint64_t ));
70- closure->request (0 )->add_params ((char *)&sample_size, sizeof (int ));
71- PsService_Stub rpc_stub (get_cmd_channel (server_index));
72- closure->cntl (0 )->set_log_id (butil::gettimeofday_ms ());
73- rpc_stub.service (closure->cntl (0 ), closure->request (0 ), closure->response (0 ),
74- closure);
107+
108+ for (int request_idx = 0 ; request_idx < request_call_num; ++request_idx){
109+ int server_index = request2server[request_idx];
110+ closure->request (request_idx)->set_cmd_id (PS_GRAPH_SAMPLE);
111+ closure->request (request_idx)->set_table_id (table_id);
112+ closure->request (request_idx)->set_client_id (_client_id);
113+ // std::string type_str = GraphNode::node_type_to_string(type);
114+ size_t node_num = node_id_buckets[request_idx].size ();
115+
116+ closure->request (request_idx)->add_params ((char *)node_id_buckets[request_idx].data (), sizeof (uint64_t )*node_num);
117+ closure->request (request_idx)->add_params ((char *)&sample_size, sizeof (int ));
118+ PsService_Stub rpc_stub (get_cmd_channel (server_index));
119+ closure->cntl (request_idx)->set_log_id (butil::gettimeofday_ms ());
120+ rpc_stub.service (closure->cntl (request_idx), closure->request (request_idx), closure->response (request_idx),
121+ closure);
122+ }
75123
76124 return fut;
77125}
@@ -124,4 +172,4 @@ int32_t GraphBrpcClient::initialize() {
124172 return 0 ;
125173}
126174}
127- }
175+ }
0 commit comments