|
26 | 26 | #define MAX_BSZ 512 |
27 | 27 | // #define GET_OUTPUT_DEBUG |
28 | 28 | struct msgdata { |
29 | | - long mtype; |
30 | | - int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens |
| 29 | + long mtype; |
| 30 | + int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens |
31 | 31 | }; |
32 | 32 |
|
33 | 33 | void GetOutput(const paddle::Tensor& x, |
34 | 34 | int64_t rank_id, |
35 | 35 | bool wait_flag, |
36 | 36 | int msg_queue_id) { |
37 | | - if (rank_id > 0) { |
38 | | - return; |
39 | | - } |
40 | | - static struct msgdata msg_rcv; |
41 | | - if (const char* inference_msg_queue_id_env_p = |
42 | | - std::getenv("INFERENCE_MSG_QUEUE_ID")) { |
43 | | - std::string inference_msg_queue_id_env_str( |
44 | | - inference_msg_queue_id_env_p); |
45 | | - int inference_msg_queue_id_from_env = |
46 | | - std::stoi(inference_msg_queue_id_env_str); |
| 37 | + if (rank_id > 0) { |
| 38 | + return; |
| 39 | + } |
| 40 | + static struct msgdata msg_rcv; |
| 41 | + if (const char* inference_msg_queue_id_env_p = |
| 42 | + std::getenv("INFERENCE_MSG_QUEUE_ID")) { |
| 43 | + std::string inference_msg_queue_id_env_str(inference_msg_queue_id_env_p); |
| 44 | + int inference_msg_queue_id_from_env = |
| 45 | + std::stoi(inference_msg_queue_id_env_str); |
47 | 46 | #ifdef GET_OUTPUT_DEBUG |
48 | | - std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " |
49 | | - << inference_msg_queue_id_from_env << std::endl; |
| 47 | + std::cout << "Your INFERENCE_MSG_QUEUE_ID is: " |
| 48 | + << inference_msg_queue_id_from_env << std::endl; |
50 | 49 | #endif |
51 | | - msg_queue_id = inference_msg_queue_id_from_env; |
52 | | - } |
53 | | - static key_t key = ftok("/dev/shm", msg_queue_id); |
54 | | - static int msgid = msgget(key, IPC_CREAT | 0666); |
| 50 | + msg_queue_id = inference_msg_queue_id_from_env; |
| 51 | + } |
| 52 | + static key_t key = ftok("/dev/shm", msg_queue_id); |
| 53 | + static int msgid = msgget(key, IPC_CREAT | 0666); |
55 | 54 |
|
56 | 55 | #ifdef GET_OUTPUT_DEBUG |
57 | | - std::cout << "get_output_key: " << key << std::endl; |
58 | | - std::cout << "get_output msgid: " << msgid << std::endl; |
| 56 | + std::cout << "get_output_key: " << key << std::endl; |
| 57 | + std::cout << "get_output msgid: " << msgid << std::endl; |
59 | 58 | #endif |
60 | 59 |
|
61 | | - int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>()); |
62 | | - int ret = -1; |
63 | | - if (!wait_flag) { |
64 | | - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); |
65 | | - } else { |
66 | | - ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); |
67 | | - } |
68 | | - if (ret == -1) { |
69 | | - out_data[0] = -2; |
70 | | - out_data[1] = 0; |
71 | | - return; |
72 | | - } |
73 | | - int bsz = msg_rcv.mtext[1]; |
| 60 | + int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>()); |
| 61 | + int ret = -1; |
| 62 | + if (!wait_flag) { |
| 63 | + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); |
| 64 | + } else { |
| 65 | + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); |
| 66 | + } |
| 67 | + if (ret == -1) { |
| 68 | + out_data[0] = -2; |
| 69 | + out_data[1] = 0; |
| 70 | + return; |
| 71 | + } |
| 72 | + int bsz = msg_rcv.mtext[1]; |
74 | 73 |
|
75 | | - for (int64_t i = 0; i < bsz + 2; i++) { |
76 | | - out_data[i] = (int64_t)msg_rcv.mtext[i]; |
77 | | - } |
| 74 | + for (int64_t i = 0; i < bsz + 2; i++) { |
| 75 | + out_data[i] = (int64_t)msg_rcv.mtext[i]; |
| 76 | + } |
78 | 77 | #ifdef GET_OUTPUT_DEBUG |
79 | | - std::cout << "get_output finished: " << msgid << std::endl; |
| 78 | + std::cout << "get_output finished: " << msgid << std::endl; |
80 | 79 | #endif |
81 | 80 |
|
82 | | - return; |
| 81 | + return; |
83 | 82 | } |
84 | 83 |
|
85 | 84 | void GetOutputStatic(const paddle::Tensor& x, int64_t rank_id, bool wait_flag) { |
86 | | - GetOutput(x, rank_id, wait_flag, 1); |
| 85 | + GetOutput(x, rank_id, wait_flag, 1); |
87 | 86 | } |
88 | 87 |
|
89 | 88 | void GetOutputDynamic(const paddle::Tensor& x, |
90 | 89 | int64_t rank_id, |
91 | 90 | bool wait_flag, |
92 | 91 | int msg_queue_id) { |
93 | | - GetOutput(x, rank_id, wait_flag, msg_queue_id); |
| 92 | + GetOutput(x, rank_id, wait_flag, msg_queue_id); |
94 | 93 | } |
95 | 94 |
|
96 | 95 | PD_BUILD_STATIC_OP(get_output) |
|
0 commit comments