Skip to content

Commit d96d36a

Browse files
committed
add gc interface and bugfix
1 parent b0f67cf commit d96d36a

File tree

2 files changed: +22 additions, −5 deletions

lmdeploy/pytorch/disagg/conn/proxy_conn.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(self):
7474
# put migrating session to `self.migration_session_shelf` for increasing fault tolerance
7575
# if a session is finished, then pop it from `self.migration_session_shelf`
7676
# if a decode instance is disconnected, then gc all blocks of these sessions in prefill instance.
77-
self.migration_session_shelf: Dict[Tuple[str, str], Set[int]] = defaultdict(set)
77+
self.migration_session_shelf: Dict[str, Set[int]] = defaultdict(set)
7878

7979
# conn_perform handler queue
8080
self.waiting_conn: asyncio.Queue[Tuple[PDConnectionMessage, asyncio.Event]] = (asyncio.Queue())
@@ -93,17 +93,15 @@ def __init__(self):
9393

9494
def reg_instance(self, role: EngineRole, endpoint: str):
9595
if role == EngineRole.Prefill:
96-
logger.error('????????????????')
9796
self.prefill_endpoints.add(endpoint)
9897
elif role == EngineRole.Decode:
99-
logger.error('????????????????')
10098
self.decode_endpoints.add(endpoint)
10199
else:
102100
raise ValueError(f'Unsupported role: {role}')
103101

104102
def dereg_instance(self, endpoint: str):
105103
if endpoint in self.prefill_endpoints:
106-
self.prefill_endpoints.pop(endpoint)
104+
self.prefill_endpoints.remove(endpoint)
107105
elif endpoint in self.decode_endpoints:
108106
dropped_key = []
109107
for conn_key in self.pool.keys():
@@ -112,7 +110,13 @@ def dereg_instance(self, endpoint: str):
112110
for k in dropped_key:
113111
self.drop(k)
114112
# TODO(JimyMa): handle side-effect by kvcache migration
115-
self.decode_endpoints.pop(endpoint)
113+
self.decode_endpoints.remove(endpoint)
114+
115+
def shelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):
116+
self.migration_session_shelf[conn_key].add(session_id)
117+
118+
def unshelf_prefill_session(self, conn_key: Tuple[str, str], session_id: int):
119+
self.migration_session_shelf[conn_key].remove(session_id)
116120

117121
async def connect(self, conn_req: PDConnectionMessage):
118122

lmdeploy/serve/proxy/proxy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,12 @@ async def connection_warmup():
507507
return JSONResponse({'SUCCESS': True})
508508

509509

510+
@app.post('/distserve/gc')
511+
async def cache_block_gc_to_be_migrated():
512+
# TODO (JimyMa): add garbage collection of to be migrated request
513+
raise NotImplementedError
514+
515+
510516
@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
511517
async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None):
512518
"""Completion API similar to OpenAI's API.
@@ -625,17 +631,21 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
625631
).model_dump(mode='json')
626632

627633
start = node_manager.pre_call(d_url)
634+
node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
628635
if request.stream is True:
629636
response = node_manager.stream_generate(request_dict, d_url, '/v1/chat/completions')
630637
background_task = node_manager.create_background_tasks(d_url, start)
638+
node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])
631639
return StreamingResponse(response, background=background_task)
632640
else:
633641
try:
634642
response = await node_manager.generate(request_dict, d_url, '/v1/chat/completions')
635643
node_manager.post_call(d_url, start)
636644
resp = JSONResponse(json.loads(response))
637645
finally:
646+
node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])
638647
return resp
648+
639649
else:
640650
raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')
641651

@@ -737,15 +747,18 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
737747
remote_block_ids=prefill_info['cache_block_ids'],
738748
remote_token_id=prefill_info['remote_token_ids'][-1],
739749
).model_dump(mode='json')
750+
node_manager.pd_connection_pool.shelf_prefill_session((p_url, d_url), prefill_info['id'])
740751

741752
start = node_manager.pre_call(d_url)
742753
if request.stream is True:
743754
response = node_manager.stream_generate(request_dict, d_url, '/v1/completions')
744755
background_task = node_manager.create_background_tasks(d_url, start)
756+
node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])
745757
return StreamingResponse(response, background=background_task)
746758
else:
747759
response = await node_manager.generate(request_dict, d_url, '/v1/completions')
748760
node_manager.post_call(d_url, start)
761+
node_manager.pd_connection_pool.unshelf_prefill_session((p_url, d_url), prefill_info['id'])
749762
return JSONResponse(json.loads(response))
750763
else:
751764
raise ValueError(f'No serving strategy named {node_manager.serving_strategy}')

Comments (0)