
Commit 827fefa

feat(misc): Profiler support
Use --enable_profiling=MODE to enable it; currently supported modes are torch_profiler and nvtx (for use with NVIDIA Nsight Systems).
1 parent 974d775 commit 827fefa

File tree

8 files changed (+331, −4 lines)


lightllm/server/api_cli.py

Lines changed: 15 additions & 0 deletions
@@ -572,4 +572,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=False,
         help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
+    parser.add_argument(
+        "--enable_profiling",
+        type=str,
+        choices=["torch_profiler", "nvtx"],
+        default=None,
+        help="""Enable profiler support.
+        This exposes the '/profiler_start' and '/profiler_stop' APIs; the profiling
+        features below are only active between those two calls.
+        Options:
+        'torch_profiler': sets up torch.profiler.profile(); trace files are saved to './trace'
+        or to the directory set by the 'LIGHTLLM_TRACE_DIR' env var;
+        'nvtx': adds NVTX marks for an external profiler such as NVIDIA Nsight Systems
+        (which you must set up yourself).
+        An NVTX range named 'LIGHTLLM_PROFILE' is added around the profiling range.""",
+    )
     return parser
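
For orientation, a minimal sketch of what the 'torch_profiler' mode plausibly sets up, based only on the help text above; the actual logic lives in lightllm/utils/profiler.py, which is among the eight changed files but not shown in this excerpt:

```python
# Hypothetical sketch of the torch_profiler mode described in the help text;
# the real implementation is in lightllm/utils/profiler.py (not shown here).
import os
import torch.profiler

trace_dir = os.environ.get("LIGHTLLM_TRACE_DIR", "./trace")

prof = torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    # Writes one trace file per profiling session into trace_dir.
    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
)

prof.start()  # corresponds to GET /profiler_start
# ... inference work to be captured ...
prof.stop()   # corresponds to GET /profiler_stop
```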

lightllm/server/api_http.py

Lines changed: 18 additions & 0 deletions
@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
     return
 
 
+@app.get("/profiler_start")
+async def profiler_start() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("start")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
+@app.get("/profiler_stop")
+async def profiler_stop() -> Response:
+    if g_objs.args.enable_profiling:
+        await g_objs.httpserver_manager.profiler_cmd("stop")
+        return JSONResponse({"status": "ok"})
+    else:
+        return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)
+
+
 @app.on_event("shutdown")
 async def shutdown():
     logger.info("Received signal to shutdown. Performing graceful shutdown...")
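
Since the endpoints are plain GET routes, exercising them from a client is straightforward; a hedged example (the host and port are assumptions, not part of this commit):

```python
# Hypothetical client for the new endpoints; adjust the base URL to your deployment.
import requests

base = "http://127.0.0.1:8000"  # assumed server address

resp = requests.get(f"{base}/profiler_start")
print(resp.status_code, resp.json())  # 200 {"status": "ok"} when --enable_profiling is set

# ... send some inference requests so they fall inside the profiling range ...

resp = requests.get(f"{base}/profiler_stop")
print(resp.status_code, resp.json())  # 400 {"message": "Profiling support not enabled"} otherwise
```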

lightllm/server/httpserver/manager.py

Lines changed: 12 additions & 1 deletion
@@ -13,7 +13,7 @@
 from frozendict import frozendict
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator
+from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator
 from websockets import ClientConnection
 from fastapi import Request
 from ..tokenizer import get_tokenizer
@@ -35,6 +35,7 @@
 from lightllm.utils.config_utils import get_vocab_size
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
+from lightllm.utils.profiler import ProfilerCmd
 from rpyc.utils.classic import obtain
 
 logger = init_logger(__name__)
@@ -650,6 +651,16 @@ async def abort(self, group_req_id: int) -> bool:
         logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
         return True
 
+    async def profiler_cmd(self, cmd: Literal["start", "stop"]):
+        receivers = [self.send_to_router]
+        if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal:
+            receivers.append(self.send_to_visual)
+        for receiver in receivers:
+            receiver.send_pyobj(
+                ProfilerCmd(cmd),
+                protocol=pickle.HIGHEST_PROTOCOL,
+            )
+
     async def recycle_resource_loop(self):
         pre_time_mark = time.time()

lightllm/server/router/manager.py

Lines changed: 17 additions & 1 deletion
@@ -26,6 +26,7 @@
 from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -106,6 +107,9 @@ def __init__(self, args: StartArgs):
             if not self.args.enable_cpu_cache
             else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
         )
+
+        profiler_mode = args.enable_profiling
+        self.profiler = ProcessProfiler(mode=profiler_mode, name="lightllm-router") if profiler_mode else None
         return
 
     async def wait_to_model_ready(self):
@@ -504,16 +508,28 @@ def _multinode_tp_generate_new_batch(self):
             raise e
         return
 
+    async def _profiler_cmd(self, cmd_obj: ProfilerCmd):
+        self.profiler.cmd(cmd_obj)
+
+        cmd = ProfilerCmd(cmd=cmd_obj.cmd)
+        while not self.shm_reqs_io_buffer.is_empty():
+            await asyncio.sleep(0.02)
+
+        self.shm_reqs_io_buffer.write_obj([cmd])
+        self.shm_reqs_io_buffer.set_ready()
+
     async def _recv_new_reqs_and_schedule(self):
         if not hasattr(self, "recv_max_count"):
             self.recv_max_count = 64
 
         try:
             # Take at most recv_max_count requests from zmq per iteration, so an overly long zmq queue cannot block the main loop.
             for _ in range(self.recv_max_count):
-                recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                 if isinstance(recv_req, GroupReqIndexes):
                     self._add_req(recv_req)
+                elif isinstance(recv_req, ProfilerCmd):
+                    await self._profiler_cmd(recv_req)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
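
Note how _profiler_cmd first applies the command to the router's own profiler, then waits for the shared-memory request buffer to drain before forwarding the command to the inference backends, so the command is never interleaved with in-flight request objects. ProcessProfiler itself is defined in the unshown lightllm/utils/profiler.py; a rough sketch of how such a per-process profiler could dispatch on the two modes (the names and details below are assumptions):

```python
# Hypothetical sketch of ProcessProfiler, inferred from its usage in this diff;
# the real implementation is in lightllm/utils/profiler.py (not shown).
import os
import torch.profiler
import torch.cuda.nvtx as nvtx


class ProcessProfiler:
    def __init__(self, mode: str, name: str, use_multi_thread: bool = False):
        self.mode = mode    # "torch_profiler" or "nvtx"
        self.name = name    # e.g. used to label trace files
        self.prof = None    # multi-thread handling is omitted in this sketch

    def cmd(self, cmd_obj) -> None:
        if cmd_obj.cmd == "start":
            if self.mode == "torch_profiler":
                trace_dir = os.environ.get("LIGHTLLM_TRACE_DIR", "./trace")
                self.prof = torch.profiler.profile(
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir)
                )
                self.prof.start()
            elif self.mode == "nvtx":
                # Range name taken from the --enable_profiling help text.
                nvtx.range_push("LIGHTLLM_PROFILE")
        elif cmd_obj.cmd == "stop":
            if self.mode == "torch_profiler" and self.prof is not None:
                self.prof.stop()
                self.prof = None
            elif self.mode == "nvtx":
                nvtx.range_pop()
```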

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 11 additions & 0 deletions
@@ -39,6 +39,7 @@
 from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
 from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
 from .multi_level_kv_cache import MultiLevelKvCacheModule
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 
 
 class ModeBackend:
@@ -231,6 +232,10 @@ def init_model(self, kvargs):
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
+        prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}"
+        prof_mode = self.args.enable_profiling
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None
+
         # Start the infer_loop threads: two threads run inference so that, in scenarios
         # with double-batch folding, CPU overhead drops and GPU utilization rises sharply.
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
@@ -343,6 +348,10 @@ def _try_read_new_reqs(self):
             self._try_read_new_reqs_multinode_tp()
         else:
             self._try_read_new_reqs_normal()
+
+        # on each loop thread
+        if self.profiler is not None:
+            self.profiler.multi_thread_helper()
         return
 
     def _try_read_new_reqs_normal(self):
@@ -408,6 +417,8 @@ def _read_reqs_buffer_and_init_reqs(self):
                 if obj.req_id in g_infer_context.requests_mapping:
                     req: InferReq = g_infer_context.requests_mapping[obj.req_id]
                     req.infer_aborted = True
+                elif isinstance(obj, ProfilerCmd):
+                    self.profiler.cmd(obj)
                 else:
                     assert False, f"error type {type(obj)}"
             if init_reqs:
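
The backend drives inference from dedicated loop threads, which is presumably why its profiler is built with use_multi_thread=True and why multi_thread_helper() is polled on every loop iteration: start/stop must take effect on the thread doing the work, not on the thread that received the command. A sketch of that polling pattern (the helper name comes from the diff; its body here is an assumption):

```python
# Hypothetical sketch of the polling behind multi_thread_helper(); the real
# logic is in lightllm/utils/profiler.py (not shown in this commit).
import threading
import torch.cuda.nvtx as nvtx


class MultiThreadProfilerHelper:
    def __init__(self):
        self._want_profiling = threading.Event()
        self._local = threading.local()  # per-thread "inside range" flag

    def cmd(self, cmd_obj) -> None:
        # Called on the command-receiving thread; only flips the shared flag.
        if cmd_obj.cmd == "start":
            self._want_profiling.set()
        else:
            self._want_profiling.clear()

    def multi_thread_helper(self) -> None:
        # Polled by each worker thread once per loop iteration, so the
        # profiling range is entered and exited on the thread doing the work.
        active = getattr(self._local, "active", False)
        if self._want_profiling.is_set() and not active:
            nvtx.range_push("LIGHTLLM_PROFILE")
            self._local.active = True
        elif not self._want_profiling.is_set() and active:
            nvtx.range_pop()
            self._local.active = False
```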

lightllm/server/visualserver/manager.py

Lines changed: 13 additions & 2 deletions
@@ -7,7 +7,7 @@
 import pickle
 import inspect
 import setproctitle
-from typing import List
+from typing import List, Union
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 
@@ -18,6 +18,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
 from rpyc.utils.classic import obtain
 
 
@@ -59,6 +60,8 @@ def __init__(
         self.visual_model_rpc_ports = visual_model_rpc_ports
         self.send_batch_size = args.visual_send_batch_size
         self.shm_req_manager = ShmReqManager()
+        prof_mode = args.enable_profiling
+        self.profiler = ProcessProfiler(prof_mode, name="lightllm-visual_server") if prof_mode else None
 
     async def wait_to_model_ready(self):
 
@@ -185,9 +188,17 @@ async def loop_for_netio_req(self):
         while True:
             try:
                 for _ in range(self.visual_recv_max_count):
-                    recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+                    recv_req: GroupReqIndexes | ProfilerCmd = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
                     if isinstance(recv_req, GroupReqIndexes):
                         self.waiting_reqs.append(recv_req)
+                    elif isinstance(recv_req, ProfilerCmd):
+                        self.profiler.cmd(recv_req)
+                        tasks = []
+                        for dp in range(self.vit_dp):
+                            for tp in range(self.vit_tp):
+                                task = asyncio.create_task(self.model_rpcs[dp][tp].profiler_cmd(recv_req))
+                                tasks.append(task)
+                        await asyncio.gather(*tasks)
                     else:
                         assert False, f"Error Req Inf {recv_req}"
                 self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 18 additions & 0 deletions
@@ -24,6 +24,7 @@
 from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profiler import ProcessProfiler
 
 
 class VisualModelRpcServer(rpyc.Service):
@@ -43,6 +44,9 @@ def exposed_init_model(self, kvargs):
         self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]
 
+        prof_mode = get_env_start_args().enable_profiling
+        prof_name = f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
+        self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name) if prof_mode else None
         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
 
@@ -116,6 +120,10 @@ def exposed_encode(self, images: List[ImageItem]):
             self.cache_client.root.set_items_embed(ids_to_set)
         return
 
+    def exposed_profiler_cmd(self, cmd_obj):
+        cmd_obj = obtain(cmd_obj)
+        self.profiler.cmd(cmd_obj)
+
 
 class VisualModelRpcClient:
     def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +146,11 @@ async def _func(*args, **kwargs):
 
             self._init_model = async_wrap(self.model.init_model)
             self._encode = async_wrap(self.model.encode)
+            self._profiler_cmd = async_wrap(self.model.profiler_cmd)
         else:
             self._init_model = self.model.exposed_init_model
             self._encode = self.model.exposed_encode
+            self._profiler_cmd = self.model.exposed_profiler_cmd
         return
 
     async def init_model(self, kvargs):
@@ -158,6 +168,14 @@ async def encode(self, images: List[ImageItem]):
         else:
             return ans
 
+    async def profiler_cmd(self, cmd_obj):
+        ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
+        if self.use_rpc:
+            await ans
+            return
+        else:
+            return
+
 
 def _init_env(port, device_id):
     # Register handling for graceful shutdown.
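
async_wrap is defined earlier in model_rpc.py; only its tail shows in the hunk context above. For readers unfamiliar with the pattern, a plausible shape that turns a blocking rpyc call into an awaitable is sketched below; the body is inferred, not the committed code:

```python
# Hypothetical sketch of async_wrap; only its context line appears in the
# diff above, so this body is an inference, not the actual implementation.
import asyncio
import rpyc


def async_wrap(f):
    f = rpyc.async_(f)  # rpyc.async_ makes the call return an AsyncResult

    async def _func(*args, **kwargs):
        ans = f(*args, **kwargs)
        # Wait on the rpyc AsyncResult in a worker thread so the event loop
        # is not blocked, then hand back the remote call's value.
        await asyncio.to_thread(ans.wait)
        return ans.value

    return _func
```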
