Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/prime_rl/inference/vllm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
engine_client,
base,
init_app_state,
models,
)
from vllm.entrypoints.openai.protocol import LoadLoRAAdapterRequest
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
Expand All @@ -67,12 +69,32 @@ def chat_with_tokens(request: Request) -> OpenAIServingChatWithTokens | None:
async def update_weights(request: Request):
    """Push new weights to every engine worker, then flush the prefix cache.

    The JSON request body may carry a ``weight_dir`` key pointing at the new
    checkpoint directory (``None`` is forwarded as-is when absent). The prefix
    cache is reset afterwards so no KV states computed with the old weights
    survive the swap.
    """
    payload = await request.json()
    weight_dir = payload.get("weight_dir")
    await engine_client(request).collective_rpc("update_weights", args=(weight_dir,))
    # Cached KV blocks were produced by the old weights; drop them all.
    await engine_client(request).reset_prefix_cache()
    return {"status": "ok"}


@router.post("/reload_weights")
async def reload_weights(request: Request):
    """Ask every engine worker to reload its weights, then flush the prefix cache.

    Any KV state cached before the reload was computed with the previous
    weights, so the prefix cache is reset to keep it from being served.
    """
    await engine_client(request).collective_rpc("reload_weights")
    # Invalidate KV entries that were cached under the old weights.
    await engine_client(request).reset_prefix_cache()
    result = {"status": "ok"}
    return result


@router.post("/load_lora_adapter")
async def load_lora_adapter(lora_request: LoadLoRAAdapterRequest, raw_request: Request):
    """Load a LoRA adapter, then reset the prefix cache.

    Thin wrapper around vLLM's native ``/v1/load_lora_adapter`` handler that
    additionally flushes the prefix cache so KV states computed before the
    adapter was loaded cannot be reused.
    """
    result = await models(raw_request).load_lora_adapter(lora_request)
    if not isinstance(result, ErrorResponse):
        # Adapter loaded successfully — drop KV blocks cached with old weights.
        await engine_client(raw_request).reset_prefix_cache()
        return {"status": "ok"}
    # Propagate vLLM's error payload with its original status code.
    return JSONResponse(content=result.model_dump(), status_code=result.error.code)


Expand Down
8 changes: 7 additions & 1 deletion src/prime_rl/utils/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ async def update_weights(
Creates a NCCL_READY marker file before calling the update endpoint to signal
to the trainer that inference workers are about to enter the receive path.
This marker is only used in NCCL broadcast mode but is harmless in filesystem mode.

Note: The server-side /update_weights endpoint automatically resets the prefix cache
to invalidate any cached KV states computed with the old weights.
"""
logger = get_logger()

Expand Down Expand Up @@ -256,6 +259,9 @@ def _is_retryable_lora_error(exception: BaseException) -> bool:
async def load_lora_adapter(admin_clients: list[AsyncClient], lora_name: str, lora_path: Path) -> None:
"""Make a HTTP post request to the vLLM server to load a LoRA adapter.

Uses our wrapper endpoint that also resets the prefix cache to invalidate
KV states computed with old weights.

Retries with exponential backoff if the adapter files are not found,
which can happen due to NFS propagation delays.
"""
Expand All @@ -271,7 +277,7 @@ async def load_lora_adapter(admin_clients: list[AsyncClient], lora_name: str, lo
async def _load_lora_adapter(admin_client: AsyncClient) -> None:
logger.debug(f"Sending request to load LoRA adapter {lora_name} from {lora_path}")
response = await admin_client.post(
"/v1/load_lora_adapter",
"/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path_posix},
)
response.raise_for_status()
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/utils/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_load_lora_adapter_succeeds_on_first_attempt():
asyncio.run(load_lora_adapter([mock_client], "test-lora", Path("/test/path")))

mock_client.post.assert_called_once_with(
"/v1/load_lora_adapter",
"/load_lora_adapter",
json={"lora_name": "test-lora", "lora_path": "/test/path"},
)

Expand Down