diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index 51a56b121ae2c..2133bce14957b 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     assert manager.activate_lora(3)
     assert manager.lora_index_to_id[0] == 2
     assert manager.lora_index_to_id[1] == 3
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 3
+    assert manager.activate_lora(1)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.deactivate_lora(2)
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.activate_lora(3)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.pin_lora(3)
+    assert manager.pin_lora(1)
+    with pytest.raises(RuntimeError):
+        assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    with pytest.raises(RuntimeError):
+        assert manager.activate_lora(2)
+
+    assert manager.deactivate_lora(3)
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.remove_lora(3)
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(3)
 
 
 def test_lru_lora_model_manager(dist_init, dummy_model):
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     assert set(manager.list_loras()) == set()
     assert all(x is None for x in manager.lora_index_to_id)
 
+    # pinning
+    assert manager.add_lora(model_lora3)
+    assert manager.activate_lora(3)
+    assert manager.add_lora(model_lora4)
+    assert manager.activate_lora(4)
+    assert set(manager.list_loras()) == {3, 4}
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(1)
+    assert manager.pin_lora(3)
+    # Remove manually
+    assert manager.remove_lora(3)
+    assert not manager.remove_lora(3)
+
+    assert set(manager.list_loras()) == {4}
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 4
+
+    assert manager.add_lora(model_lora1)
+    assert manager.pin_lora(1)
+    assert manager.add_lora(model_lora2)
+    assert manager.activate_lora(2)
+
+    assert set(manager.list_loras()) == {1, 2}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] == 2
+
+    assert manager.remove_oldest_lora()
+    assert set(manager.list_loras()) == {1}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] is None
+
+    with pytest.raises(RuntimeError):
+        assert manager.remove_oldest_lora()
+
+    assert set(manager.list_loras()) == {1}
+
 
 def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 75d417f525e3a..f7eae257fdd16 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1009,6 +1009,9 @@ def remove_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> Set[int]:
         return self.model_executor.list_loras()
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
+
     def check_health(self) -> None:
         self.model_executor.check_health()
 
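For context, the new engine-level entry point can be exercised as below. This is a minimal usage sketch rather than part of the patch: the base model name and adapter path are placeholders, and it assumes an engine built with LoRA enabled.

```python
from vllm import EngineArgs, LLMEngine
from vllm.lora.request import LoRARequest

# Placeholder base model; any LoRA-capable setup works.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                         enable_lora=True,
                         max_loras=2)
engine = LLMEngine.from_engine_args(engine_args)

# Register an adapter, then pin it so the LoRA LRU caches never
# evict it to make room for other adapters.
engine.add_lora(LoRARequest("sql_adapter", 1, "/path/to/lora"))
engine.pin_lora(1)
```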
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index a2212459f034e..6137cecd881d0 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -84,6 +84,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index f7c608af1ad39..235b5bc47021d 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -100,6 +100,13 @@ def remove_lora(self, lora_id: int) -> bool:
             lora_id=lora_id,
         )
 
+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "pin_lora",
+            lora_id=lora_id,
+        )
+
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 4d01939c2e38b..7c2520b5a64f5 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -86,6 +86,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 3ad201f4757ec..0a654200ed796 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -99,6 +99,10 @@ def remove_lora(self, lora_id: int) -> bool:
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index e7f0e887921b7..c5e2fb0f67736 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -65,6 +65,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 3e82856866d85..afb9ba4550671 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -525,6 +525,12 @@ def remove_lora(self, lora_id: int) -> bool:
             self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
         return bool(self._registered_loras.pop(lora_id, None))
 
+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        raise NotImplementedError(
+            "Pinning is not supported in LoRAModelManager. "
+            "Use LRUCacheLoRAModelManager for pinning.")
+
     # TODO see if this can be vectorized
     def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
         (base_indices, sampler_indices, sampler_indices_padded,
@@ -777,6 +783,26 @@ def remove_oldest_lora(self) -> bool:
             return True
         return False
 
+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        self._pin_lora_in_cpu_cache(lora_id)
+        self._pin_lora_in_gpu_cache(lora_id)
+        return True
+
+    def _pin_lora_in_cpu_cache(self, lora_id: int):
+        try:
+            self._registered_loras.pin(lora_id)
+        except ValueError as err:
+            raise ValueError("Pinning failed. "
+                             f"LoRA {lora_id} is not registered.") from err
+
+    def _pin_lora_in_gpu_cache(self, lora_id: int):
+        if lora_id not in self._active_loras:
+            # Move the LoRA to the GPU if it is not already active.
+            self.activate_lora(lora_id)
+
+        self._active_loras.pin(lora_id)
+
 
 def create_lora_manager(
     model: nn.Module,
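Restating the contract implemented above as a short walk-through (illustrative only: `manager` stands for the two-slot `LRUCacheLoRAModelManager` exercised by the updated tests, and its construction is elided):

```python
manager.pin_lora(2)       # pins id 2 in the CPU cache, then activates and pins it on GPU
manager.activate_lora(1)  # evicts the least recently used *unpinned* adapter, never id 2
manager.pin_lora(1)       # both GPU slots now hold pinned adapters

manager.pin_lora(99)      # ValueError: id 99 was never registered, so the CPU pin fails
manager.activate_lora(3)  # RuntimeError: no unpinned slot is left to evict
```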
+ "Use LRUCacheLoRAModelManager for pinning") # type: ignore + # TODO see if this can be vectorized def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, @@ -777,6 +783,26 @@ def remove_oldest_lora(self) -> bool: return True return False + def pin_lora(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + self._pin_lora_in_cpu_cache(lora_id) + self._pin_lora_in_gpu_cache(lora_id) + return True + + def _pin_lora_in_cpu_cache(self, lora_id: int): + try: + self._registered_loras.pin(lora_id) + except ValueError as err: + raise ValueError("Pinning failed. " + f"LoRA {lora_id} is not registered.") from err + + def _pin_lora_in_gpu_cache(self, lora_id: int): + if lora_id not in self._active_loras: + # move lora to gpu if not already active + self.activate_lora(lora_id) + + self._active_loras.pin(lora_id) + def create_lora_manager( model: nn.Module, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 498b2b9ddb18a..ca4903c23bcaa 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -221,6 +221,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self._lora_manager.remove_lora(lora_id) + def pin_lora(self, lora_id: int) -> bool: + return self._lora_manager.pin_lora(lora_id) + def remove_all_loras(self): self._lora_manager.remove_all_loras() diff --git a/vllm/utils.py b/vllm/utils.py index 27a7b1042d88f..ce5c377eff2d4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -15,7 +15,7 @@ from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, - Hashable, List, Optional, OrderedDict, Tuple, TypeVar, + Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union) import numpy as np @@ -44,6 +44,13 @@ T = TypeVar("T") +class _Sentinel: + ... + + +ALL_PINNED_SENTINEL = _Sentinel() + + class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() @@ -67,6 +74,7 @@ class LRUCache(Generic[T]): def __init__(self, capacity: int): self.cache: OrderedDict[Hashable, T] = OrderedDict() + self.pinned_items: Set[Hashable] = set() self.capacity = capacity def __contains__(self, key: Hashable) -> bool: @@ -102,14 +110,36 @@ def put(self, key: Hashable, value: T) -> None: self.cache.move_to_end(key) self._remove_old_if_needed() + def pin(self, key: Hashable) -> None: + """ + Pins a key in the cache preventing it from being + evicted in the LRU order. 
+ """ + if key not in self.cache: + raise ValueError(f"Cannot pin key: {key} not in cache.") + self.pinned_items.add(key) + + def _unpin(self, key: Hashable) -> None: + self.pinned_items.remove(key) + def _on_remove(self, key: Hashable, value: Optional[T]): pass - def remove_oldest(self): + def remove_oldest(self, remove_pinned=False): if not self.cache: return - key, value = self.cache.popitem(last=False) - self._on_remove(key, value) + + if not remove_pinned: + # pop the oldest item in the cache that is not pinned + lru_key = next( + (key for key in self.cache if key not in self.pinned_items), + ALL_PINNED_SENTINEL) + if lru_key is ALL_PINNED_SENTINEL: + raise RuntimeError("All items are pinned, " + "cannot remove oldest from the cache.") + else: + lru_key = next(iter(self.cache)) + self.pop(lru_key) def _remove_old_if_needed(self) -> None: while len(self.cache) > self.capacity: @@ -120,13 +150,16 @@ def pop(self, default_value: Optional[T] = None) -> Optional[T]: run_on_remove = key in self.cache value: Optional[T] = self.cache.pop(key, default_value) + # remove from pinned items + if key in self.pinned_items: + self._unpin(key) if run_on_remove: self._on_remove(key, value) return value def clear(self): while len(self.cache) > 0: - self.remove_oldest() + self.remove_oldest(remove_pinned=True) self.cache.clear() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e24835a1ea7fb..a321eafce1a2f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -878,6 +878,11 @@ def remove_lora(self, lora_id: int) -> bool: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_lora(lora_id) + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_lora(lora_id) + def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index e334ffbb755bb..c60764ef1bed8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -333,6 +333,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: return self.model_runner.remove_lora(lora_id) + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + def list_loras(self) -> Set[int]: return self.model_runner.list_loras() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 3d52fd71ec4b8..dc09718de4a32 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -70,6 +70,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError + @abstractmethod + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + @abstractmethod def list_loras(self) -> Set[int]: raise NotImplementedError @@ -86,6 +90,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise ValueError(f"{type(self)} does not support LoRA") + def pin_lora(self, lora_id: int) -> bool: + return ValueError( + f"{type(self)} does not support LoRA") # type: ignore + def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA")