Skip to content

Commit 97b97e3

Browse files
committed
Use Cython to improve post_handle performance
1 parent 42e8199 commit 97b97e3

File tree

6 files changed

+203
-2
lines changed

6 files changed

+203
-2
lines changed

lightllm/server/router/model_infer/infer_batch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ def update_finish_status(self, eos_ids):
338338
self.finish_status.set_status(FinishStatus.FINISHED_STOP)
339339
elif (
340340
self.cur_output_len > 0
341-
and self.get_last_gen_token() in eos_ids
342341
and self.sampling_param.shm_param.ignore_eos is False
342+
and self.get_last_gen_token() in eos_ids
343343
):
344344
self.finish_status.set_status(FinishStatus.FINISHED_STOP)
345345
elif self.cur_output_len >= self.sampling_param.shm_param.max_new_tokens:

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import rpyc
55
import torch
66
import socket
7+
import time
78
from datetime import timedelta
89
from typing import Dict, List, Tuple, Callable, Optional
910
from transformers.configuration_utils import PretrainedConfig
@@ -374,6 +375,81 @@ def _post_handle(
374375
is_chuncked_mode: bool,
375376
do_filter_finished_reqs: bool,
376377
extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
378+
) -> List[int]:
379+
"""
380+
extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
381+
约束输出等模式,设置自己请求内部的状态机的状态,并添加额外的停止判定条件等。
382+
"""
383+
if not hasattr(self, "_post_handle_impl"):
384+
try:
385+
finished_req_ids = self._fast_post_handle(
386+
run_reqs,
387+
next_token_ids,
388+
next_token_logprobs,
389+
is_chuncked_mode,
390+
do_filter_finished_reqs,
391+
extra_post_req_handle_func,
392+
)
393+
self._post_handle_impl = self._fast_post_handle
394+
self.logger.info("use _fast_post_handle")
395+
return finished_req_ids
396+
except:
397+
finished_req_ids = self._python_post_handle(
398+
run_reqs,
399+
next_token_ids,
400+
next_token_logprobs,
401+
is_chuncked_mode,
402+
do_filter_finished_reqs,
403+
extra_post_req_handle_func,
404+
)
405+
self.logger.info("use _python_post_handle")
406+
self._post_handle_impl = self._python_post_handle
407+
return finished_req_ids
408+
else:
409+
return self._post_handle_impl(
410+
run_reqs,
411+
next_token_ids,
412+
next_token_logprobs,
413+
is_chuncked_mode,
414+
do_filter_finished_reqs,
415+
extra_post_req_handle_func,
416+
)
417+
418+
def _fast_post_handle(
    self,
    run_reqs: List[InferReq],
    next_token_ids,
    next_token_logprobs,
    is_chuncked_mode: bool,
    do_filter_finished_reqs: bool,
    extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
) -> List[int]:
    """
    Delegate post handling to ``cython_fast_impl.fast_post_handle`` and time it.

    All arguments are forwarded unchanged (with ``self`` as the first
    argument). The compiled module is imported lazily inside the call, so an
    ImportError surfaces to the caller when the extension is not available.

    Returns the list of finished request ids produced by the compiled
    implementation. On the dp master, calls slower than 1 ms are logged.
    """
    from . import cython_fast_impl

    # perf_counter() is monotonic and high resolution, so the measured interval
    # cannot be skewed by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    finished_req_ids = cython_fast_impl.fast_post_handle(
        self,
        run_reqs,
        next_token_ids,
        next_token_logprobs,
        is_chuncked_mode,
        do_filter_finished_reqs,
        extra_post_req_handle_func,
    )
    cost_time = time.perf_counter() - start
    if self.is_master_in_dp and cost_time > 0.001:
        self.logger.info(f"post handle cost time {cost_time} s, batch_size: {len(run_reqs)}")
    return finished_req_ids
443+
444+
# 一些可以复用的通用功能函数
445+
def _python_post_handle(
446+
self,
447+
run_reqs: List[InferReq],
448+
next_token_ids,
449+
next_token_logprobs,
450+
is_chuncked_mode: bool,
451+
do_filter_finished_reqs: bool,
452+
extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
377453
) -> List[int]:
378454
"""
379455
extra_post_req_handle_func 用于提供在一个请求确定输出的时候,给出额外的后处理操作,主要是用于
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import cython
2+
from typing import List, Optional, Callable
3+
from ..infer_batch import InferReq, FinishStatus
4+
from .base_backend import ModeBackend
5+
6+
7+
def __update_finish_status(self: InferReq, gen_new_token_id: int, eos_ids: List[int]):
    """
    Set ``self.finish_status`` if the request meets one of the stop conditions.

    Conditions are checked in order and the first match wins:
      1. the generated tokens end with one of ``self.stop_sequences``
         -> FINISHED_STOP
      2. ``gen_new_token_id`` is in ``eos_ids`` and ``ignore_eos`` is False
         -> FINISHED_STOP
      3. the output length reached ``max_new_tokens`` -> FINISHED_LENGTH

    Nothing is changed when no condition matches.
    """
    produced = self.cur_output_len

    # Condition 1: compare the tail of the token buffer with each stop sequence.
    for stop_seq in self.stop_sequences:
        seq_len = len(stop_seq)
        if seq_len <= 0 or produced < seq_len:
            continue
        end = self.shm_req.input_len + produced
        tail = self.shm_req.shm_prompt_ids.arr[(end - seq_len) : end]
        for k in range(seq_len):
            if tail[k] != stop_seq[k]:
                break
        else:
            # Every position matched.
            self.finish_status.set_status(FinishStatus.FINISHED_STOP)
            return

    params = self.sampling_param.shm_param

    # Condition 2: an EOS token was generated and EOS is not ignored.
    hit_eos = produced > 0 and params.ignore_eos is False and gen_new_token_id in eos_ids
    if hit_eos:
        self.finish_status.set_status(FinishStatus.FINISHED_STOP)
        return

    # Condition 3: the generation length limit was reached.
    if produced >= params.max_new_tokens:
        self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
32+
33+
34+
# @cython.boundscheck(False)
35+
# @cython.wraparound(False)
36+
def fast_post_handle(
    self: ModeBackend,
    run_reqs: List[InferReq],
    next_token_ids_,
    next_token_logprobs_,
    is_chuncked_mode: bool,
    do_filter_finished_reqs: bool,
    extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
) -> List[int]:
    """
    Record the sampled token of every request in ``run_reqs`` and collect the
    ids of the requests that finished or were aborted.

    ``next_token_ids_`` and ``next_token_logprobs_`` are viewed as typed
    memoryviews of ``long long`` and ``float`` and indexed in step with
    ``run_reqs``.

    ``extra_post_req_handle_func`` is an optional hook called once a request's
    output token is determined. It is mainly intended for modes such as
    constrained output, to update the request's internal state machine and to
    add extra stop conditions.

    When ``do_filter_finished_reqs`` is true, the finished ids are also passed
    to ``g_infer_context.filter``. Returns the list of finished request ids.
    """
    from lightllm.server.router.model_infer.infer_batch import g_infer_context

    # A plain empty list; building a throwaway list and clearing it bought nothing.
    finished_req_ids = []
    next_token_ids: cython.longlong[:] = cython.declare(cython.longlong[:], next_token_ids_)
    next_token_logprobs: cython.float[:] = cython.declare(cython.float[:], next_token_logprobs_)
    is_master_in_dp: cython.bint = self.is_master_in_dp
    # Bound to a separate typed local rather than re-annotating the parameter.
    chunked_mode: cython.bint = is_chuncked_mode
    eos_ids = self.eos_id

    i: cython.Py_ssize_t
    for i in range(len(run_reqs)):
        req_obj: InferReq = run_reqs[i]
        shm_req = req_obj.shm_req
        # Same width as the memoryview element, so the id is never narrowed.
        next_token_id: cython.longlong = next_token_ids[i]
        next_token_logprob: cython.float = next_token_logprobs[i]
        cur_total_len = shm_req.input_len + req_obj.cur_output_len

        if chunked_mode:
            new_kv_len = min(cur_total_len, req_obj.cur_kv_len + shm_req.chunked_prefill_size)
        else:
            new_kv_len = cur_total_len

        req_obj.cur_kv_len = new_kv_len
        if is_master_in_dp:
            shm_req.shm_cur_kv_len = req_obj.cur_kv_len

        # Check early whether the request was aborted; if so, put it straight
        # into the finished list.
        if shm_req.router_aborted:
            finished_req_ids.append(shm_req.request_id)
            continue

        # Skip requests that have not yet reached the stage where a token
        # has to be output.
        if req_obj.cur_kv_len < cur_total_len:
            continue

        # Write the newly generated token's information into the request.
        gen_token_index = cur_total_len
        shm_req.shm_prompt_ids.arr[gen_token_index] = next_token_id
        shm_req.shm_logprobs.arr[gen_token_index] = next_token_logprob
        req_obj.cur_output_len += 1

        req_obj.out_token_id_count[next_token_id] += 1
        __update_finish_status(req_obj, next_token_id, eos_ids)

        if extra_post_req_handle_func is not None:
            extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)

        # Check whether the generation finish condition has been met.
        is_finished = req_obj.finish_status.is_finished()
        if is_finished or shm_req.router_aborted:
            finished_req_ids.append(shm_req.request_id)

        if is_master_in_dp:
            # shm_cur_kv_len and shm_cur_output_len are read by the router
            # scheduling process; finish_token_index, finish_status and
            # candetoken_out_len are needed by the detokenization process.
            # Keep this write order to avoid asynchronous coordination problems.
            shm_req.shm_cur_output_len = req_obj.cur_output_len

            if is_finished:
                shm_req.finish_token_index = gen_token_index
                shm_req.finish_status = req_obj.finish_status

            shm_req.candetoken_out_len = req_obj.cur_output_len

    if do_filter_finished_reqs:
        g_infer_context.filter(finished_req_ids)

    return finished_req_ids

lightllm/server/router/model_infer/mode_backend/generic_post_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
6565
int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64)
6666
int64_batch_next_token_ids[:] = batch_next_token_ids
6767
batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1))
68-
return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
68+
return int64_batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
6969
else:
7070
assert False, "dead path"
7171

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,4 @@ flashinfer-python==0.2.4
8888
sgl-kernel
8989
httpx==0.28.1
9090
librosa==0.11.0
91+
Cython

setup.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from setuptools import setup, find_packages
2+
from Cython.Build import cythonize
23

34
package_data = {"lightllm": ["common/all_kernel_configs/*/*.json"]}
45
setup(
@@ -28,4 +29,10 @@
2829
"triton",
2930
],
3031
package_data=package_data,
32+
ext_modules=cythonize(
33+
[
34+
"lightllm/server/router/model_infer/mode_backend/cython_fast_impl.pyx",
35+
]
36+
),
37+
zip_safe=False,
3138
)

0 commit comments

Comments
 (0)