15 changes: 15 additions & 0 deletions lightllm/server/api_cli.py
@@ -572,4 +572,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=False,
help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
)
parser.add_argument(
"--enable_profiling",
type=str,
choices=["torch_profiler", "nvtx"],
default=None,
help="""Enable profiler support.
This will expose '/profiler_start' and '/profiler_stop' API,
below profiling features will only be enabled in this range.
Options:
'torch_profiler': will setup torch.profiler.profile(), trace files will be saved to './trace',
or set by 'LIGHTLLM_TRACE_DIR' env;
'nvtx': will add NVTX marks for external profiler like NVIDIA Nsight System
(you should set it up by yourself).
A NVTX range named 'LIGHTLLM_PROFILE' will be added within the profiling range.""",
)
return parser
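Note for reviewers: the ProcessProfiler that backs these two modes lives in lightllm/utils/profiler.py and is not part of this hunk. As a rough sketch of just what the help text above promises (the './trace' default, the LIGHTLLM_TRACE_DIR override, and the LIGHTLLM_PROFILE NVTX range; everything else is an assumption), the start/stop behaviour could look like this:

```python
import os
import torch


class _ProfilerSketch:
    """Illustrative only -- not the actual lightllm/utils/profiler.py implementation."""

    def __init__(self, mode: str, name: str):
        self.mode = mode      # "torch_profiler" or "nvtx"
        self.name = name
        self._prof = None

    def start(self):
        if self.mode == "torch_profiler":
            # Traces go to ./trace unless LIGHTLLM_TRACE_DIR overrides it.
            trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace")
            self._prof = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
            )
            self._prof.start()
        elif self.mode == "nvtx":
            # Mark the region so an external profiler (e.g. Nsight Systems) can find it.
            torch.cuda.nvtx.range_push("LIGHTLLM_PROFILE")

    def stop(self):
        if self.mode == "torch_profiler" and self._prof is not None:
            self._prof.stop()
            self._prof = None
        elif self.mode == "nvtx":
            torch.cuda.nvtx.range_pop()
```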
18 changes: 18 additions & 0 deletions lightllm/server/api_http.py
@@ -335,6 +335,24 @@ async def kv_move_status(websocket: WebSocket):
return


@app.get("/profiler_start")
async def profiler_start() -> Response:
if g_objs.args.enable_profiling:
await g_objs.httpserver_manager.profiler_cmd("start")
return JSONResponse({"status": "ok"})
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)


@app.get("/profiler_stop")
async def profiler_stop() -> Response:
if g_objs.args.enable_profiling:
await g_objs.httpserver_manager.profiler_cmd("stop")
return JSONResponse({"status": "ok"})
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=400)


@app.on_event("shutdown")
async def shutdown():
logger.info("Received signal to shutdown. Performing graceful shutdown...")
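With --enable_profiling set, a capture can be toggled from outside the server with two GET requests. A minimal client sketch (the host and port are assumptions; point it at wherever the server is listening):

```python
import requests

BASE = "http://127.0.0.1:8000"  # assumed address; match the running server

requests.get(f"{BASE}/profiler_start").raise_for_status()  # begin capture
# ... drive some representative /generate traffic here ...
print(requests.get(f"{BASE}/profiler_stop").json())        # {"status": "ok"}
```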
13 changes: 12 additions & 1 deletion lightllm/server/httpserver/manager.py
@@ -13,7 +13,7 @@
from frozendict import frozendict

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from typing import Union, List, Tuple, Dict, Optional, AsyncGenerator
from typing import Literal, Union, List, Tuple, Dict, Optional, AsyncGenerator
from websockets import ClientConnection
from fastapi import Request
from ..tokenizer import get_tokenizer
@@ -35,6 +35,7 @@
from lightllm.utils.config_utils import get_vocab_size
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.error_utils import NixlPrefillNodeStopGenToken
from lightllm.utils.profiler import ProfilerCmd
from rpyc.utils.classic import obtain

logger = init_logger(__name__)
@@ -650,6 +651,16 @@ async def abort(self, group_req_id: int) -> bool:
logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}")
return True

async def profiler_cmd(self, cmd: Literal["start", "stop"]):
receivers = [self.send_to_router]
if self.pd_mode.is_P_or_NORMAL() and self.enable_multimodal:
receivers.append(self.send_to_visual)
for receiver in receivers:
receiver.send_pyobj(
ProfilerCmd(cmd),
protocol=pickle.HIGHEST_PROTOCOL,
)

async def recycle_resource_loop(self):
pre_time_mark = time.time()

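ProfilerCmd also comes from lightllm/utils/profiler.py, which is not shown in this diff; the call sites here and in the router only need a small picklable message carrying the command string, so its shape is presumably something like the following sketch (an assumption, not the actual definition):

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class ProfilerCmd:
    # "start" or "stop", matching what HttpServerManager.profiler_cmd() sends
    # and what RouterManager._profiler_cmd() re-wraps for the shm buffer.
    cmd: Literal["start", "stop"]
```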
18 changes: 17 additions & 1 deletion lightllm/server/router/manager.py
@@ -26,6 +26,7 @@
from lightllm.server.multi_level_kv_cache.cpu_cache_client import CpuKvCacheClient
from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
from lightllm.utils.log_utils import init_logger, log_time_ready
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.metrics.manager import MetricClient
from lightllm.common.basemodel.infer_lock import g_router_lock
@@ -106,6 +107,9 @@ def __init__(self, args: StartArgs):
if not self.args.enable_cpu_cache
else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
)

profiler_mode = args.enable_profiling
self.profiler = ProcessProfiler(mode=profiler_mode, name="lightllm-router") if profiler_mode else None
return

async def wait_to_model_ready(self):
@@ -504,16 +508,28 @@ def _multinode_tp_generate_new_batch(self):
raise e
return

async def _profiler_cmd(self, cmd_obj: ProfilerCmd):
self.profiler.cmd(cmd_obj)

cmd = ProfilerCmd(cmd=cmd_obj.cmd)
while not self.shm_reqs_io_buffer.is_empty():
await asyncio.sleep(0.02)
Comment on lines +515 to +516

Contributor (medium):

This while loop with asyncio.sleep(0.02) is a form of busy-waiting. While it might be acceptable for a non-production feature like profiling, a more efficient approach would be to use an event or condition variable to signal when the buffer is empty. This would avoid unnecessary polling and context switching.

self.shm_reqs_io_buffer.write_obj([cmd])
self.shm_reqs_io_buffer.set_ready()
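Regarding the polling concern raised above: the general event-driven shape the reviewer describes is sketched below. It only works as-is when producer and consumer share one event loop; here the buffer is drained by a separate inference process, so some cross-process signal (or an acknowledgement written back through the shm buffer) would still be needed in practice:

```python
import asyncio


class IoBufferGate:
    """Sketch of event-based signalling instead of sleep-polling (single-process view)."""

    def __init__(self):
        self._empty = asyncio.Event()
        self._empty.set()  # nothing pending at startup

    def mark_busy(self):
        # Called when an object is written into the buffer.
        self._empty.clear()

    def mark_drained(self):
        # Called by the consumer once it has read the buffer.
        self._empty.set()

    async def wait_until_empty(self):
        # Replaces: while not buffer.is_empty(): await asyncio.sleep(0.02)
        await self._empty.wait()
```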

async def _recv_new_reqs_and_schedule(self):
if not hasattr(self, "recv_max_count"):
self.recv_max_count = 64

try:
# Fetch at most recv_max_count requests from zmq per pass, so a flood of queued requests cannot block the main loop.
for _ in range(self.recv_max_count):
recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
recv_req: Union[GroupReqIndexes, ProfilerCmd] = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
if isinstance(recv_req, GroupReqIndexes):
self._add_req(recv_req)
elif isinstance(recv_req, ProfilerCmd):
await self._profiler_cmd(recv_req)
else:
assert False, f"Error Req Inf {recv_req}"

11 changes: 11 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -39,6 +39,7 @@
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
from .multi_level_kv_cache import MultiLevelKvCacheModule
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd


class ModeBackend:
@@ -231,6 +232,10 @@ def init_model(self, kvargs):
if self.args.mtp_mode:
self.init_mtp_draft_model(kvargs)

prof_name = f"lightllm-model_backend-node{self.node_rank}_dev{get_current_device_id()}"
prof_mode = self.args.enable_profiling
self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name, use_multi_thread=True) if prof_mode else None

# Start infer_loop_thread: two threads run inference, which for workloads with dual-batch (folded) inference
# lowers CPU overhead and significantly improves GPU utilization.
self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
@@ -343,6 +348,10 @@ def _try_read_new_reqs(self):
self._try_read_new_reqs_multinode_tp()
else:
self._try_read_new_reqs_normal()

# on each loop thread
if self.profiler is not None:
self.profiler.multi_thread_helper()
return

def _try_read_new_reqs_normal(self):
@@ -408,6 +417,8 @@ def _read_reqs_buffer_and_init_reqs(self):
if obj.req_id in g_infer_context.requests_mapping:
req: InferReq = g_infer_context.requests_mapping[obj.req_id]
req.infer_aborted = True
elif isinstance(obj, ProfilerCmd):
self.profiler.cmd(obj)
else:
assert False, f"error type {type(obj)}"
if init_reqs:
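The use_multi_thread=True flag plus the multi_thread_helper() call inside _try_read_new_reqs suggest that the start/stop command is applied lazily by each infer loop thread rather than by the thread that receives it. That is only an inference from these call sites; a sketch of that coordination pattern, assuming it is roughly what the helper does:

```python
import threading


class _PerThreadProfilerSketch:
    """Guess at the use_multi_thread coordination: record the command once,
    let every infer loop thread apply it to its own thread-local profiler."""

    def __init__(self):
        self._pending = None               # "start" / "stop", written by cmd()
        self._local = threading.local()    # what this thread has already applied

    def cmd(self, cmd_obj):
        self._pending = cmd_obj.cmd

    def multi_thread_helper(self):
        # Polled from each infer loop thread (see _try_read_new_reqs above).
        applied = getattr(self._local, "applied", None)
        if self._pending is not None and self._pending != applied:
            if self._pending == "start":
                self._start_on_this_thread()
            else:
                self._stop_on_this_thread()
            self._local.applied = self._pending

    def _start_on_this_thread(self):
        pass  # e.g. start a thread-local torch.profiler instance

    def _stop_on_this_thread(self):
        pass  # e.g. stop it and flush the trace
```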
15 changes: 13 additions & 2 deletions lightllm/server/visualserver/manager.py
@@ -7,7 +7,7 @@
import pickle
import inspect
import setproctitle
from typing import List
from typing import List, Union
Contributor (medium):

For consistency with other files, let's use Optional for type hints of optional values. This also prepares for the change on line 62.

Suggested change:
- from typing import List, Union
+ from typing import List, Union, Optional

from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
from lightllm.server.core.objs import ShmReqManager, StartArgs

@@ -18,6 +18,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.profiler import ProcessProfiler, ProfilerCmd
from rpyc.utils.classic import obtain


@@ -59,6 +60,8 @@ def __init__(
self.visual_model_rpc_ports = visual_model_rpc_ports
self.send_batch_size = args.visual_send_batch_size
self.shm_req_manager = ShmReqManager()
prof_mode = args.enable_profiling
self.profiler = ProcessProfiler(prof_mode, name="lightllm-visual_server") if prof_mode else None

async def wait_to_model_ready(self):

@@ -185,9 +188,17 @@ async def loop_for_netio_req(self):
while True:
try:
for _ in range(self.visual_recv_max_count):
recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
recv_req: GroupReqIndexes | ProfilerCmd = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
if isinstance(recv_req, GroupReqIndexes):
self.waiting_reqs.append(recv_req)
elif isinstance(recv_req, ProfilerCmd):
self.profiler.cmd(recv_req)
tasks = []
for dp in range(self.vit_dp):
for tp in range(self.vit_tp):
task = asyncio.create_task(self.model_rpcs[dp][tp].profiler_cmd(recv_req))
tasks.append(task)
await asyncio.gather(*tasks)
else:
assert False, f"Error Req Inf {recv_req}"
self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
18 changes: 18 additions & 0 deletions lightllm/server/visualserver/model_infer/model_rpc.py
@@ -24,6 +24,7 @@
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.profiler import ProcessProfiler


class VisualModelRpcServer(rpyc.Service):
@@ -43,6 +44,9 @@ def exposed_init_model(self, kvargs):
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.data_type = kvargs["data_type"]

prof_mode = get_env_start_args().enable_profiling
prof_name = f"lightllm-visual-vit_dp{self.dp_rank_id}_tp{self.tp_rank_id}"
self.profiler = ProcessProfiler(mode=prof_mode, name=prof_name) if prof_mode else None
init_vision_distributed_env(kvargs)
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)

@@ -116,6 +120,10 @@ def exposed_encode(self, images: List[ImageItem]):
self.cache_client.root.set_items_embed(ids_to_set)
return

def exposed_profiler_cmd(self, cmd_obj):
cmd_obj = obtain(cmd_obj)
self.profiler.cmd(cmd_obj)


class VisualModelRpcClient:
def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
@@ -138,9 +146,11 @@ async def _func(*args, **kwargs):

self._init_model = async_wrap(self.model.init_model)
self._encode = async_wrap(self.model.encode)
self._profiler_cmd = async_wrap(self.model.profiler_cmd)
else:
self._init_model = self.model.exposed_init_model
self._encode = self.model.exposed_encode
self._profiler_cmd = self.model.exposed_profiler_cmd
return

async def init_model(self, kvargs):
@@ -158,6 +168,14 @@ async def encode(self, images: List[ImageItem]):
else:
return ans

async def profiler_cmd(self, cmd_obj):
ans: rpyc.AsyncResult = self._profiler_cmd(cmd_obj)
if self.use_rpc:
await ans
return
else:
return


def _init_env(port, device_id):
# Register handlers for graceful shutdown
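On the client side, profiler_cmd is wrapped with rpyc async_ exactly like encode, and the AsyncResult is awaited, so /profiler_start and /profiler_stop only return once every ViT worker has handled the command. A minimal fan-out sketch mirroring the asyncio.gather loop in visualserver/manager.py above (the client list and command object here are placeholders):

```python
import asyncio


async def broadcast_profiler_cmd(model_rpcs, cmd_obj):
    # Send the same command to every dp/tp ViT worker and wait for all of them.
    tasks = [client.profiler_cmd(cmd_obj) for row in model_rpcs for client in row]
    await asyncio.gather(*tasks)
```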