[LoRA] Add support for pinning lora adapters in the LRU cache #5603

Merged: 7 commits, Jun 21, 2024
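The diff below threads a new pin_lora call from the engine API down through the executors and workers to the LoRA manager, so that a pinned adapter is exempt from LRU eviction. As a rough usage sketch (the engine object and adapter ID here are illustrative, not part of this diff):

# Sketch only: assumes `engine` is an LLMEngine built with enable_lora=True and
# that an adapter was previously registered via add_lora() with lora_id=1.
engine.pin_lora(1)               # engine -> executor -> worker -> LoRAModelManager
assert 1 in engine.list_loras()  # the adapter stays registered and resident

# Pinning an unknown adapter raises ValueError; if every cache slot is pinned,
# a later eviction attempt raises RuntimeError (see the vllm/utils.py hunk below).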
64 changes: 64 additions & 0 deletions tests/lora/test_lora_manager.py
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.pin_lora(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_lora(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_lora(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_lora(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.pin_lora(3)
    assert manager.pin_lora(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_lora(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_lora(2)

    assert manager.deactivate_lora(3)
    assert manager.pin_lora(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.remove_lora(3)
    with pytest.raises(ValueError):
        assert manager.pin_lora(3)


def test_lru_lora_model_manager(dist_init, dummy_model):
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
    assert set(manager.list_loras()) == set()
    assert all(x is None for x in manager.lora_index_to_id)

    # pinning
    assert manager.add_lora(model_lora3)
    assert manager.activate_lora(3)
    assert manager.add_lora(model_lora4)
    assert manager.activate_lora(4)
    assert set(manager.list_loras()) == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_lora(1)
    assert manager.pin_lora(3)
    # Remove manually
    assert manager.remove_lora(3)
    assert not manager.remove_lora(3)

    assert set(manager.list_loras()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4

    assert manager.add_lora(model_lora1)
    assert manager.pin_lora(1)
    assert manager.add_lora(model_lora2)
    assert manager.activate_lora(2)

    assert set(manager.list_loras()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2

    assert manager.remove_oldest_lora()
    assert set(manager.list_loras()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None

    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_lora()

    assert set(manager.list_loras()) == {1}


def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                       sql_lora_files):
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -976,5 +976,8 @@ def remove_lora(self, lora_id: int) -> bool:
    def list_loras(self) -> Set[int]:
        return self.model_executor.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def check_health(self) -> None:
        self.model_executor.check_health()
3 changes: 3 additions & 0 deletions vllm/executor/cpu_executor.py
@@ -84,6 +84,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

7 changes: 7 additions & 0 deletions vllm/executor/distributed_gpu_executor.py
@@ -100,6 +100,13 @@ def remove_lora(self, lora_id: int) -> bool:
            lora_id=lora_id,
        )

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self._run_workers(
            "pin_lora",
            lora_id=lora_id,
        )

    def list_loras(self) -> Set[int]:
        return self._run_workers("list_loras")

4 changes: 4 additions & 0 deletions vllm/executor/executor_base.py
@@ -86,6 +86,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError  # type: ignore

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError
4 changes: 4 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -99,6 +99,10 @@ def remove_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        assert lora_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

3 changes: 3 additions & 0 deletions vllm/executor/neuron_executor.py
@@ -65,6 +65,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        return self.driver_worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.driver_worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.driver_worker.list_loras()

26 changes: 26 additions & 0 deletions vllm/lora/models.py
@@ -525,6 +525,12 @@ def remove_lora(self, lora_id: int) -> bool:
            self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
        return bool(self._registered_loras.pop(lora_id, None))

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    # TODO see if this can be vectorized
    def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
        (base_indices, sampler_indices, sampler_indices_padded,
@@ -777,6 +783,26 @@ def remove_oldest_lora(self) -> bool:
            return True
        return False

    def pin_lora(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_loras.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_loras:
            # move lora to gpu if not already active
            self.activate_lora(lora_id)

        self._active_loras.pin(lora_id)


def create_lora_manager(
        model: nn.Module,
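The LRU-cache manager above pins an adapter in both the CPU (registered) and GPU (active) caches, activating it first if needed. A hedged sketch of the observable behavior, with a hypothetical manager instance and adapter IDs:

# Sketch only: `manager` is an LRUCacheLoRAModelManager; adapter 2 is
# registered but inactive, adapter 9 was never added.
assert manager.pin_lora(2)  # activates 2 if needed, then pins it in both caches
try:
    manager.pin_lora(9)     # unknown adapter
except ValueError:
    pass                    # "Pinning failed. LoRA 9 is not registered."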
3 changes: 3 additions & 0 deletions vllm/lora/worker_manager.py
@@ -221,6 +221,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        return self._lora_manager.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self._lora_manager.pin_lora(lora_id)

    def remove_all_loras(self):
        self._lora_manager.remove_all_loras()

43 changes: 38 additions & 5 deletions vllm/utils.py
@@ -14,7 +14,7 @@
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                    Union)

import numpy as np
@@ -43,6 +43,13 @@
T = TypeVar("T")


class _Sentinel:
    ...


ALL_PINNED_SENTINEL = _Sentinel()


class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()
@@ -66,6 +73,7 @@ class LRUCache(Generic[T]):

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
@@ -101,14 +109,36 @@ def put(self, key: Hashable, value: T) -> None:
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
"""
if key not in self.cache:
raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key)

def _unpin(self, key: Hashable) -> None:
self.pinned_items.remove(key)

def _on_remove(self, key: Hashable, value: Optional[T]):
pass

def remove_oldest(self):
def remove_oldest(self, remove_pinned=False):
if not self.cache:
return
key, value = self.cache.popitem(last=False)
self._on_remove(key, value)

if not remove_pinned:
# pop the oldest item in the cache that is not pinned
lru_key = next(
(key for key in self.cache if key not in self.pinned_items),
ALL_PINNED_SENTINEL)
if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError("All items are pinned, "
"cannot remove oldest from the cache.")
else:
lru_key = next(iter(self.cache))
self.pop(lru_key)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
@@ -119,13 +149,16 @@ def pop(self,
            default_value: Optional[T] = None) -> Optional[T]:
        run_on_remove = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        while len(self.cache) > 0:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()


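For reference, a minimal standalone sketch of the new LRUCache pinning semantics introduced above (keys, values, and capacity are made up for illustration):

from vllm.utils import LRUCache

cache: LRUCache[str] = LRUCache(capacity=2)
cache.put("a", "adapter-a")
cache.put("b", "adapter-b")

cache.pin("a")         # "a" is now exempt from LRU eviction
cache.remove_oldest()  # evicts "b", the oldest unpinned entry
assert "a" in cache

try:
    cache.remove_oldest()  # everything left is pinned
except RuntimeError:
    pass  # "All items are pinned, cannot remove oldest from the cache."

cache.clear()  # clear() ignores pins by calling remove_oldest(remove_pinned=True)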
5 changes: 5 additions & 0 deletions vllm/worker/model_runner.py
@@ -866,6 +866,11 @@ def remove_lora(self, lora_id: int) -> bool:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
        return self.lora_manager.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")
3 changes: 3 additions & 0 deletions vllm/worker/worker.py
@@ -323,6 +323,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

8 changes: 8 additions & 0 deletions vllm/worker/worker_base.py
@@ -70,6 +70,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError
@@ -86,6 +90,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
