
[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server #10546


Merged · 19 commits · Apr 15, 2025
61 changes: 56 additions & 5 deletions docs/source/features/lora.md
@@ -106,19 +106,18 @@ curl http://localhost:8000/v1/completions \

## Dynamically serving LoRA Adapters

In addition to serving LoRA adapters at server startup, the vLLM server supports dynamically configuring LoRA adapters at runtime through dedicated API endpoints and plugins. This feature is particularly useful when you need the flexibility to change models on the fly.

Note: Enabling this feature in production environments is risky, as it exposes model adapter management to end users.

To enable dynamic LoRA configuration, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
is set to `True`.

```bash
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
```

### Using API Endpoints
Loading a LoRA Adapter:

To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary details of the adapter to be loaded.
@@ -153,6 +152,58 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
}'
```
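
The full request bodies are collapsed in this diff. For illustration, here is a minimal Python sketch of issuing a load request at runtime, assuming a JSON payload with `lora_name` and `lora_path` fields; the adapter name and path below are placeholders:

```python
# A minimal sketch of loading an adapter at runtime, assuming a JSON
# payload with "lora_name" and "lora_path" fields; the adapter name and
# path below are placeholders.
import json
from urllib.request import Request, urlopen

payload = {
    "lora_name": "sql_adapter",
    "lora_path": "/path/to/sql-lora-adapter",
}
req = Request(
    "http://localhost:8000/v1/load_lora_adapter",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urlopen(req) as resp:
    # A success status indicates the adapter is now available for requests.
    print(resp.status)
```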

### Using Plugins
Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adapters. LoRAResolver plugins let you load LoRA adapters from both local and remote sources, such as the local file system or S3. Whenever a request references a model name that has not been loaded yet, the LoRAResolver tries to resolve and load the corresponding LoRA adapter.

You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds.

You can either install existing plugins or implement your own.

Steps to implement your own LoRAResolver plugin:
1. Implement the LoRAResolver interface.

Example of a simple S3 LoRAResolver implementation:

```python
import os
import s3fs
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver

class S3LoRAResolver(LoRAResolver):
def __init__(self):
self.s3 = s3fs.S3FileSystem()
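        # Path templates use {base_model_name} and {lora_name} placeholders
        # and are read from the environment.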
self.s3_path_format = os.getenv("S3_PATH_TEMPLATE")
self.local_path_format = os.getenv("LOCAL_PATH_TEMPLATE")

async def resolve_lora(self, base_model_name, lora_name):
s3_path = self.s3_path_format.format(base_model_name=base_model_name, lora_name=lora_name)
local_path = self.local_path_format.format(base_model_name=base_model_name, lora_name=lora_name)

# Download the LoRA from S3 to the local path
await self.s3._get(
s3_path, local_path, recursive=True, maxdepth=1
)

lora_request = LoRARequest(
lora_name=lora_name,
lora_path=local_path,
lora_int_id=abs(hash(lora_name))
)
return lora_request
```
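
The resolver above reads its path templates from the environment. A possible configuration, with a hypothetical bucket and local cache directory; both templates must contain the `{base_model_name}` and `{lora_name}` placeholders the resolver expects:

```python
# Hypothetical template values; adjust the bucket and local cache
# directory for your deployment.
import os

os.environ["S3_PATH_TEMPLATE"] = "s3://my-bucket/loras/{base_model_name}/{lora_name}/"
os.environ["LOCAL_PATH_TEMPLATE"] = "/tmp/loras/{base_model_name}/{lora_name}"
```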

2. Register the LoRAResolver plugin.

```python
from vllm.lora.resolver import LoRAResolverRegistry

s3_resolver = S3LoRAResolver()
LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
```

For more details, refer to [vLLM's Plugin System](../design/plugin_system.md).
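
To have the resolver registered automatically at server startup, it can be shipped as a package that exposes the registration call through a plugin entry point. A minimal sketch, assuming the `vllm.general_plugins` entry-point group described in the plugin system documentation; the package and module names here are hypothetical:

```python
# setup.py for a hypothetical plugin package; "register_resolver" would be
# a module-level function that performs the LoRAResolverRegistry call above.
from setuptools import setup

setup(
    name="vllm-s3-lora-resolver",
    version="0.1",
    py_modules=["s3_resolver_plugin"],
    entry_points={
        "vllm.general_plugins":
        ["s3_lora_resolver = s3_resolver_plugin:register_resolver"],
    },
)
```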

## New format for `--lora-modules`

Previously, users would provide LoRA modules as either a key-value pair or in JSON format. For example:
209 changes: 209 additions & 0 deletions tests/entrypoints/openai/test_lora_resolvers.py
@@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import suppress
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Optional
from unittest.mock import MagicMock

import pytest

from vllm.config import MultiModalConfig
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
from vllm.transformers_utils.tokenizer import get_tokenizer

MODEL_NAME = "openai-community/gpt2"
BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

MOCK_RESOLVER_NAME = "mock_test_resolver"


@dataclass
class MockHFConfig:
model_type: str = "any"


@dataclass
class MockModelConfig:
"""Minimal mock ModelConfig for testing."""
model: str = MODEL_NAME
tokenizer: str = MODEL_NAME
trust_remote_code: bool = False
tokenizer_mode: str = "auto"
max_model_len: int = 100
tokenizer_revision: Optional[str] = None
multimodal_config: MultiModalConfig = field(
default_factory=MultiModalConfig)
hf_config: MockHFConfig = field(default_factory=MockHFConfig)
logits_processor_pattern: Optional[str] = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
encoder_config = None
generation_config: str = "auto"

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}


class MockLoRAResolver(LoRAResolver):

async def resolve_lora(self, base_model_name: str,
lora_name: str) -> Optional[LoRARequest]:
if lora_name == "test-lora":
return LoRARequest(lora_name="test-lora",
lora_int_id=1,
lora_local_path="/fake/path/test-lora")
elif lora_name == "invalid-lora":
return LoRARequest(lora_name="invalid-lora",
lora_int_id=2,
lora_local_path="/fake/path/invalid-lora")
return None


@pytest.fixture(autouse=True)
def register_mock_resolver():
"""Fixture to register and unregister the mock LoRA resolver."""
resolver = MockLoRAResolver()
LoRAResolverRegistry.register_resolver(MOCK_RESOLVER_NAME, resolver)
yield
# Cleanup: remove the resolver after the test runs
if MOCK_RESOLVER_NAME in LoRAResolverRegistry.resolvers:
del LoRAResolverRegistry.resolvers[MOCK_RESOLVER_NAME]


@pytest.fixture
def mock_serving_setup():
"""Provides a mocked engine and serving completion instance."""
mock_engine = MagicMock(spec=MQLLMEngineClient)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False

def mock_add_lora_side_effect(lora_request: LoRARequest):
"""Simulate engine behavior when adding LoRAs."""
if lora_request.lora_name == "test-lora":
# Simulate successful addition
return
elif lora_request.lora_name == "invalid-lora":
# Simulate failure during addition (e.g. invalid format)
raise ValueError(f"Simulated failure adding LoRA: "
f"{lora_request.lora_name}")

mock_engine.add_lora.side_effect = mock_add_lora_side_effect
mock_engine.generate.reset_mock()
mock_engine.add_lora.reset_mock()

mock_model_config = MockModelConfig()
models = OpenAIServingModels(engine_client=mock_engine,
base_model_paths=BASE_MODEL_PATHS,
model_config=mock_model_config)

serving_completion = OpenAIServingCompletion(mock_engine,
mock_model_config,
models,
request_logger=None)

return mock_engine, serving_completion


@pytest.mark.asyncio
async def test_serving_completion_with_lora_resolver(mock_serving_setup,
monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

mock_engine, serving_completion = mock_serving_setup

lora_model_name = "test-lora"
req_found = CompletionRequest(
model=lora_model_name,
prompt="Generate with LoRA",
)

# Suppress potential errors during the mocked generate call,
# as we are primarily checking for add_lora and generate calls
with suppress(Exception):
await serving_completion.create_completion(req_found)

mock_engine.add_lora.assert_called_once()
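    # The resolved LoRARequest is passed positionally to add_lora.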
called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == lora_model_name

mock_engine.generate.assert_called_once()
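    # The same LoRARequest should be forwarded to generate() via the
    # lora_request keyword argument.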
called_lora_request = mock_engine.generate.call_args[1]['lora_request']
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == lora_model_name


@pytest.mark.asyncio
async def test_serving_completion_resolver_not_found(mock_serving_setup,
monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

mock_engine, serving_completion = mock_serving_setup

non_existent_model = "non-existent-lora-adapter"
req = CompletionRequest(
model=non_existent_model,
prompt="what is 1+1?",
)

response = await serving_completion.create_completion(req)

mock_engine.add_lora.assert_not_called()
mock_engine.generate.assert_not_called()

assert isinstance(response, ErrorResponse)
assert response.code == HTTPStatus.NOT_FOUND.value
assert non_existent_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_resolver_add_lora_fails(
mock_serving_setup, monkeypatch):
monkeypatch.setenv("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "true")

mock_engine, serving_completion = mock_serving_setup

invalid_model = "invalid-lora"
req = CompletionRequest(
model=invalid_model,
prompt="what is 1+1?",
)

response = await serving_completion.create_completion(req)

# Assert add_lora was called before the failure
mock_engine.add_lora.assert_called_once()
called_lora_request = mock_engine.add_lora.call_args[0][0]
assert isinstance(called_lora_request, LoRARequest)
assert called_lora_request.lora_name == invalid_model

# Assert generate was *not* called due to the failure
mock_engine.generate.assert_not_called()

# Assert the correct error response
assert isinstance(response, ErrorResponse)
assert response.code == HTTPStatus.BAD_REQUEST.value
assert invalid_model in response.message


@pytest.mark.asyncio
async def test_serving_completion_flag_not_set(mock_serving_setup):
mock_engine, serving_completion = mock_serving_setup

lora_model_name = "test-lora"
req_found = CompletionRequest(
model=lora_model_name,
prompt="Generate with LoRA",
)

await serving_completion.create_completion(req_found)

mock_engine.add_lora.assert_not_called()
mock_engine.generate.assert_not_called()
74 changes: 74 additions & 0 deletions tests/lora/test_resolver.py
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import pytest

from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


class DummyLoRAResolver(LoRAResolver):
"""A dummy LoRA resolver for testing."""

async def resolve_lora(self, base_model_name: str,
lora_name: str) -> Optional[LoRARequest]:
if lora_name == "test_lora":
return LoRARequest(
lora_name=lora_name,
lora_path=f"/dummy/path/{base_model_name}/{lora_name}",
lora_int_id=abs(hash(lora_name)))
return None


def test_resolver_registry_registration():
"""Test basic resolver registration functionality."""
registry = LoRAResolverRegistry
resolver = DummyLoRAResolver()

# Register a new resolver
registry.register_resolver("dummy", resolver)
assert "dummy" in registry.get_supported_resolvers()

# Get registered resolver
retrieved_resolver = registry.get_resolver("dummy")
assert retrieved_resolver is resolver


def test_resolver_registry_duplicate_registration():
"""Test registering a resolver with an existing name."""
registry = LoRAResolverRegistry
resolver1 = DummyLoRAResolver()
resolver2 = DummyLoRAResolver()

registry.register_resolver("dummy", resolver1)
registry.register_resolver("dummy", resolver2)

assert registry.get_resolver("dummy") is resolver2


def test_resolver_registry_unknown_resolver():
"""Test getting a non-existent resolver."""
registry = LoRAResolverRegistry

with pytest.raises(KeyError, match="not found"):
registry.get_resolver("unknown_resolver")


@pytest.mark.asyncio
async def test_dummy_resolver_resolve():
"""Test the dummy resolver's resolve functionality."""
dummy_resolver = DummyLoRAResolver()
base_model_name = "base_model_test"
lora_name = "test_lora"

# Test successful resolution
result = await dummy_resolver.resolve_lora(base_model_name, lora_name)
assert isinstance(result, LoRARequest)
assert result.lora_name == lora_name
assert result.lora_path == f"/dummy/path/{base_model_name}/{lora_name}"

# Test failed resolution
result = await dummy_resolver.resolve_lora(base_model_name,
"nonexistent_lora")
assert result is None