[Core] Faster startup for LoRA enabled models (vllm-project#4634)
Yard1 authored May 8, 2024
1 parent 5510cf0 commit ad932a2
Showing 3 changed files with 47 additions and 18 deletions.
10 changes: 10 additions & 0 deletions vllm/lora/models.py
@@ -119,6 +119,16 @@ def __init__(
         self.rank = rank
         self.loras: Dict[str, LoRALayerWeights] = loras
 
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
     @property
     def extra_vocab_size(self) -> int:
         return max(lora.extra_vocab_size
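
The new clone() leans on dict.copy() being a shallow copy: the returned LoRAModel gets its own id and its own mapping, but every LoRALayerWeights value (and the tensors inside it) is the same object as in the source model, so no weights are re-allocated. A minimal sketch of that sharing behaviour, using a hypothetical TinyLoRAModel stand-in rather than vLLM's actual LoRAModel class:

# Minimal sketch of the shallow-copy sharing that clone() relies on.
# TinyLoRAModel is a hypothetical stand-in, not vLLM's LoRAModel.
from typing import Dict

import torch


class TinyLoRAModel:

    def __init__(self, lora_model_id: int, rank: int,
                 loras: Dict[str, torch.Tensor]) -> None:
        self.id = lora_model_id
        self.rank = rank
        self.loras = loras

    def clone(self, lora_model_id: int) -> "TinyLoRAModel":
        # Shallow copy: a new dict, but the same tensor objects.
        return self.__class__(lora_model_id, rank=self.rank,
                              loras=self.loras.copy())


base = TinyLoRAModel(1, rank=8, loras={"layer0": torch.zeros(8, 16)})
cloned = base.clone(lora_model_id=2)
assert cloned.id != base.id                              # new id
assert cloned.loras is not base.loras                    # new mapping
assert cloned.loras["layer0"] is base.loras["layer0"]    # shared tensor
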
26 changes: 22 additions & 4 deletions vllm/lora/worker_manager.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod, abstractproperty
-from typing import Any, Dict, List, Set, Type
+from contextlib import contextmanager
+from typing import Any, Dict, List, Literal, Set, Type, Union
 
 import torch

@@ -25,6 +26,17 @@ def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
         self.device = device
         self.lora_config = lora_config
 
+        # If False, do not cache. If None, cache is empty.
+        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
+
+    @contextmanager
+    def dummy_lora_cache(self):
+        """Use this context manager to reuse the dummy lora model
+        to avoid creating it repeatedly."""
+        self._cached_dummy_lora = None
+        yield
+        self._cached_dummy_lora = False
+
     @abstractproperty
     def is_enabled(self) -> bool:
         ...
@@ -174,9 +186,15 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
     def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
         if lora_request.lora_int_id in self.list_loras():
             return False
-        return self._lora_manager.add_lora(
-            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
-                                                 rank, self.embedding_modules))
+        if isinstance(self._cached_dummy_lora, LoRAModel):
+            dummy_lora = self._cached_dummy_lora.clone(
+                lora_request.lora_int_id)
+        else:
+            dummy_lora = self._lora_manager.create_dummy_lora(
+                lora_request.lora_int_id, rank, self.embedding_modules)
+        if self._cached_dummy_lora is None:
+            self._cached_dummy_lora = dummy_lora
+        return self._lora_manager.add_lora(dummy_lora)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         if lora_request.lora_int_id in self.list_loras():
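
The cache added to the worker manager is a three-state sentinel: False means caching is disabled (the default outside profiling), None means caching is enabled but nothing has been stored yet, and a LoRAModel instance means a dummy model is available to clone. A rough, self-contained illustration of that pattern follows; the class and method names are simplified stand-ins, not vLLM's real API:

# Rough illustration of the three-state dummy-LoRA cache. DummyLoRA,
# ManagerSketch and _create_dummy() are hypothetical stand-ins.
from contextlib import contextmanager
from typing import Literal, Union


class DummyLoRA:

    def __init__(self, lora_id: int) -> None:
        self.lora_id = lora_id

    def clone(self, lora_id: int) -> "DummyLoRA":
        # Cheap: reuses the existing (here imaginary) weights.
        return DummyLoRA(lora_id)


class ManagerSketch:

    def __init__(self) -> None:
        # False: do not cache. None: cache enabled but empty.
        self._cached_dummy_lora: Union[None, Literal[False], DummyLoRA] = False
        self.created = 0  # counts expensive from-scratch creations

    @contextmanager
    def dummy_lora_cache(self):
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    def add_dummy_lora(self, lora_id: int) -> DummyLoRA:
        if isinstance(self._cached_dummy_lora, DummyLoRA):
            return self._cached_dummy_lora.clone(lora_id)
        dummy = self._create_dummy(lora_id)
        if self._cached_dummy_lora is None:
            self._cached_dummy_lora = dummy
        return dummy

    def _create_dummy(self, lora_id: int) -> DummyLoRA:
        self.created += 1  # stands in for allocating real dummy weights
        return DummyLoRA(lora_id)


manager = ManagerSketch()
with manager.dummy_lora_cache():
    for lora_id in range(1, 9):
        manager.add_dummy_lora(lora_id)
assert manager.created == 1  # only the first call builds weights from scratch
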
29 changes: 15 additions & 14 deletions vllm/worker/model_runner.py
@@ -835,20 +835,21 @@ def profile_run(self) -> None:
         dummy_lora_requests = []
         dummy_lora_requests_per_seq = []
         if self.lora_config:
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_local_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
 
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
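
Taken together, the three files implement one optimization: during profile_run, the dummy_lora_cache() context manager turns caching on, so only the first add_dummy_lora call actually builds dummy LoRA weights; the remaining max_loras - 1 warmup requests receive cheap clones that share those tensors under new ids. Leaving the with block resets _cached_dummy_lora to False, so the cached model cannot leak into calls made outside the warmup loop.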
