Skip to content

Commit

Permalink
[LoRA] Add support for pinning lora adapters in the LRU cache (vllm-p…
Browse files Browse the repository at this point in the history
  • Loading branch information
rohithkrn authored Jun 21, 2024
1 parent 8e178a3 commit ecc8149
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 5 deletions.
64 changes: 64 additions & 0 deletions tests/lora/test_lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 3
assert manager.activate_lora(1)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.deactivate_lora(2)
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 1
assert manager.activate_lora(3)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
assert manager.pin_lora(3)
assert manager.pin_lora(1)
with pytest.raises(RuntimeError):
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 3
assert manager.lora_index_to_id[1] == 1
with pytest.raises(RuntimeError):
assert manager.activate_lora(2)

assert manager.deactivate_lora(3)
assert manager.pin_lora(2)
assert manager.lora_index_to_id[0] == 2
assert manager.lora_index_to_id[1] == 1
assert manager.remove_lora(3)
with pytest.raises(ValueError):
assert manager.pin_lora(3)


def test_lru_lora_model_manager(dist_init, dummy_model):
Expand Down Expand Up @@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
assert set(manager.list_loras()) == set()
assert all(x is None for x in manager.lora_index_to_id)

# pinning
assert manager.add_lora(model_lora3)
assert manager.activate_lora(3)
assert manager.add_lora(model_lora4)
assert manager.activate_lora(4)
assert set(manager.list_loras()) == {3, 4}
with pytest.raises(ValueError):
assert manager.pin_lora(1)
assert manager.pin_lora(3)
# Remove manually
assert manager.remove_lora(3)
assert not manager.remove_lora(3)

assert set(manager.list_loras()) == {4}
assert manager.lora_index_to_id[0] is None
assert manager.lora_index_to_id[1] == 4

assert manager.add_lora(model_lora1)
assert manager.pin_lora(1)
assert manager.add_lora(model_lora2)
assert manager.activate_lora(2)

assert set(manager.list_loras()) == {1, 2}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] == 2

assert manager.remove_oldest_lora()
assert set(manager.list_loras()) == {1}
assert manager.lora_index_to_id[0] == 1
assert manager.lora_index_to_id[1] is None

with pytest.raises(RuntimeError):
assert manager.remove_oldest_lora()

assert set(manager.list_loras()) == {1}


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files):
Expand Down
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,9 @@ def remove_lora(self, lora_id: int) -> bool:
def list_loras(self) -> Set[int]:
return self.model_executor.list_loras()

def pin_lora(self, lora_id: int) -> bool:
return self.model_executor.pin_lora(lora_id)

def check_health(self) -> None:
self.model_executor.check_health()

Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/cpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

Expand Down
7 changes: 7 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ def remove_lora(self, lora_id: int) -> bool:
lora_id=lora_id,
)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self._run_workers(
"pin_lora",
lora_id=lora_id,
)

def list_loras(self) -> Set[int]:
return self._run_workers("list_loras")

Expand Down
4 changes: 4 additions & 0 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError # type: ignore

@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError
Expand Down
4 changes: 4 additions & 0 deletions vllm/executor/gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

Expand Down
3 changes: 3 additions & 0 deletions vllm/executor/neuron_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.driver_worker.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras()

Expand Down
26 changes: 26 additions & 0 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,12 @@ def remove_lora(self, lora_id: int) -> bool:
self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
return bool(self._registered_loras.pop(lora_id, None))

def pin_lora(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in LoRAModelManager."
"Use LRUCacheLoRAModelManager for pinning") # type: ignore

# TODO see if this can be vectorized
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
(base_indices, sampler_indices, sampler_indices_padded,
Expand Down Expand Up @@ -777,6 +783,26 @@ def remove_oldest_lora(self) -> bool:
return True
return False

def pin_lora(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id)
return True

def _pin_lora_in_cpu_cache(self, lora_id: int):
try:
self._registered_loras.pin(lora_id)
except ValueError as err:
raise ValueError("Pinning failed. "
f"LoRA {lora_id} is not registered.") from err

def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_loras:
# move lora to gpu if not already active
self.activate_lora(lora_id)

self._active_loras.pin(lora_id)


def create_lora_manager(
model: nn.Module,
Expand Down
3 changes: 3 additions & 0 deletions vllm/lora/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self._lora_manager.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self._lora_manager.pin_lora(lora_id)

def remove_all_loras(self):
self._lora_manager.remove_all_loras()

Expand Down
43 changes: 38 additions & 5 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union)

import numpy as np
Expand Down Expand Up @@ -44,6 +44,13 @@
T = TypeVar("T")


class _Sentinel:
...


ALL_PINNED_SENTINEL = _Sentinel()


class Device(enum.Enum):
GPU = enum.auto()
CPU = enum.auto()
Expand All @@ -67,6 +74,7 @@ class LRUCache(Generic[T]):

def __init__(self, capacity: int):
self.cache: OrderedDict[Hashable, T] = OrderedDict()
self.pinned_items: Set[Hashable] = set()
self.capacity = capacity

def __contains__(self, key: Hashable) -> bool:
Expand Down Expand Up @@ -102,14 +110,36 @@ def put(self, key: Hashable, value: T) -> None:
self.cache.move_to_end(key)
self._remove_old_if_needed()

def pin(self, key: Hashable) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if key not in self.cache:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)

def _unpin(self, key: Hashable) -> None:
self.pinned_items.remove(key)

def _on_remove(self, key: Hashable, value: Optional[T]):
pass

def remove_oldest(self):
def remove_oldest(self, remove_pinned=False):
if not self.cache:
return
key, value = self.cache.popitem(last=False)
self._on_remove(key, value)

if not remove_pinned:
# pop the oldest item in the cache that is not pinned
lru_key = next(
(key for key in self.cache if key not in self.pinned_items),
ALL_PINNED_SENTINEL)
if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError("All items are pinned, "
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key)

def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
Expand All @@ -120,13 +150,16 @@ def pop(self,
default_value: Optional[T] = None) -> Optional[T]:
run_on_remove = key in self.cache
value: Optional[T] = self.cache.pop(key, default_value)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value

def clear(self):
while len(self.cache) > 0:
self.remove_oldest()
self.remove_oldest(remove_pinned=True)
self.cache.clear()


Expand Down
5 changes: 5 additions & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,11 @@ def remove_lora(self, lora_id: int) -> bool:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
return self.lora_manager.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
if not self.lora_manager:
raise RuntimeError("LoRA is not enabled.")
Expand Down
3 changes: 3 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
return self.model_runner.remove_lora(lora_id)

def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)

def list_loras(self) -> Set[int]:
return self.model_runner.list_loras()

Expand Down
8 changes: 8 additions & 0 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def list_loras(self) -> Set[int]:
raise NotImplementedError
Expand All @@ -86,6 +90,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")

def pin_lora(self, lora_id: int) -> bool:
return ValueError(
f"{type(self)} does not support LoRA") # type: ignore

def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")

Expand Down

0 comments on commit ecc8149

Please sign in to comment.