Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Support load and unload LoRA in api server #6566

Merged
merged 11 commits into from
Sep 6, 2024
1 change: 0 additions & 1 deletion docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,5 @@ pydantic >= 2.8
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
mistral_common >= 1.3.4
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
52 changes: 52 additions & 0 deletions docs/source/models/lora.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,55 @@ The following is an example request
"max_tokens": 7,
"temperature": 0
}' | jq


Dynamically serving LoRA Adapters
---------------------------------

In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
to change models on-the-fly is needed.

Note: Enabling this feature in production environments is risky as user may participate model adapter management.

To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.

.. code-block:: bash

export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True


Loading a LoRA Adapter:

To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter.

Example request to load a LoRA adapter:

.. code-block:: bash

curl -X POST http://localhost:8000/v1/load_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter",
"lora_path": "/path/to/sql-lora-adapter"
}'

Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
cannot be found or loaded, an appropriate error message will be returned.

Unloading a LoRA Adapter:

To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
with the name or ID of the adapter to be unloaded.

Example request to unload a LoRA adapter:

.. code-block:: bash

curl -X POST http://localhost:8000/v1/unload_lora_adapter \
-H "Content-Type: application/json" \
-d '{
"lora_name": "sql_adapter"
}'
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_generate_multiple_loras.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def zephyr_lora_files():
@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
lora_request = [
LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
for idx in range(len(PROMPTS))
]
# Multiple SamplingParams should be matched with each prompt
Expand Down
107 changes: 107 additions & 0 deletions tests/entrypoints/openai/test_serving_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from http import HTTPStatus
from unittest.mock import MagicMock

import pytest

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_engine import OpenAIServing

MODEL_NAME = "meta-llama/Llama-2-7b"
LORA_LOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' added successfully.")
LORA_UNLOADING_SUCCESS_MESSAGE = (
"Success: LoRA adapter '{lora_name}' removed successfully.")


async def _async_serving_engine_init():
mock_engine_client = MagicMock(spec=AsyncEngineClient)
mock_model_config = MagicMock(spec=ModelConfig)
# Set the max_model_len attribute to avoid missing attribute
mock_model_config.max_model_len = 2048

serving_engine = OpenAIServing(mock_engine_client,
mock_model_config,
served_model_names=[MODEL_NAME],
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_engine


@pytest.mark.asyncio
async def test_load_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
request = LoadLoraAdapterRequest(lora_name="adapter",
lora_path="/path/to/adapter2")
response = await serving_engine.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
assert len(serving_engine.lora_requests) == 1
assert serving_engine.lora_requests[0].lora_name == "adapter"


@pytest.mark.asyncio
async def test_load_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
request = LoadLoraAdapterRequest(lora_name="", lora_path="")
response = await serving_engine.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_load_lora_adapter_duplicate():
serving_engine = await _async_serving_engine_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 1

request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
assert len(serving_engine.lora_requests) == 1


@pytest.mark.asyncio
async def test_unload_lora_adapter_success():
serving_engine = await _async_serving_engine_init()
request = LoadLoraAdapterRequest(lora_name="adapter1",
lora_path="/path/to/adapter1")
response = await serving_engine.load_lora_adapter(request)
assert len(serving_engine.lora_requests) == 1

request = UnloadLoraAdapterRequest(lora_name="adapter1")
response = await serving_engine.unload_lora_adapter(request)
assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
lora_name='adapter1')
assert len(serving_engine.lora_requests) == 0


@pytest.mark.asyncio
async def test_unload_lora_adapter_missing_fields():
serving_engine = await _async_serving_engine_init()
request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
response = await serving_engine.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST


@pytest.mark.asyncio
async def test_unload_lora_adapter_not_found():
serving_engine = await _async_serving_engine_init()
request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
response = await serving_engine.unload_lora_adapter(request)
assert isinstance(response, ErrorResponse)
assert response.type == "InvalidUserInput"
assert response.code == HTTPStatus.BAD_REQUEST
40 changes: 38 additions & 2 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoraAdapterRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
TokenizeResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
Expand Down Expand Up @@ -340,6 +342,40 @@ async def stop_profile():
return Response(status_code=200)


if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
logger.warning(
"Lora dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!")

@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest):
response = await openai_serving_chat.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await openai_serving_completion.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)

@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
response = await openai_serving_chat.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

response = await openai_serving_completion.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)

return Response(status_code=200, content=response)


Jeffwan marked this conversation as resolved.
Show resolved Hide resolved
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
Expand Down
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,13 @@ class DetokenizeRequest(OpenAIBaseModel):

class DetokenizeResponse(OpenAIBaseModel):
prompt: str


class LoadLoraAdapterRequest(BaseModel):
lora_name: str
lora_path: str


class UnloadLoraAdapterRequest(BaseModel):
lora_name: str
lora_int_id: Optional[int] = Field(default=None)
79 changes: 78 additions & 1 deletion vllm/entrypoints/openai/serving_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
TokenizeRequest,
UnloadLoraAdapterRequest)
# yapf: enable
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
Expand All @@ -32,6 +34,7 @@
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter

logger = init_logger(__name__)

Expand Down Expand Up @@ -78,6 +81,7 @@ def __init__(

self.served_model_names = served_model_names

self.lora_id_counter = AtomicCounter(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Jeffwan this is not required here. asyncio operations all happen in the same thread. Can change this to be a simple int field.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@njhill yeah, that makes sense. Let me file a follow up PR to improve it.

self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
Expand Down Expand Up @@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob,
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)

async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return self.create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

return None

async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return self.create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)

return None

async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret

lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."

async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret

lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
VLLM_ALLOW_ENGINE_USE_RAY: bool = False
VLLM_PLUGINS: Optional[List[str]] = None
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -409,6 +410,12 @@ def get_default_config_root():
# If set, vLLM will use Triton implementations of AWQ.
"VLLM_USE_TRITON_AWQ":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),

# If set, allow loading or unloading lora adapters in runtime,
"VLLM_ALLOW_RUNTIME_LORA_UPDATING":
lambda:
(os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
("1", "true")),
Jeffwan marked this conversation as resolved.
Show resolved Hide resolved
}

# end-env-vars-definition
Expand Down
Loading
Loading