Commit f5dda63

[LoRA] Add support for pinning lora adapters in the LRU cache (#5603)
1 parent 7187507 commit f5dda63

13 files changed, +171 -5 lines changed
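
This commit adds a `pin_lora(lora_id)` entry point at every layer of the stack so that a LoRA adapter can be pinned in the LRU caches and exempted from eviction until it is explicitly removed. A minimal usage sketch follows; the model name, adapter path, and engine setup are illustrative assumptions, not part of this diff:

```python
from vllm import EngineArgs, LLMEngine
from vllm.lora.request import LoRARequest

# Illustrative setup (placeholder model and adapter path).
engine = LLMEngine.from_engine_args(
    EngineArgs(model="meta-llama/Llama-2-7b-hf",
               enable_lora=True,
               max_loras=2))

engine.add_lora(LoRARequest("sql_adapter", 1, "/path/to/adapter"))
engine.pin_lora(1)     # adapter 1 is now exempt from LRU eviction
engine.remove_lora(1)  # removing an adapter also unpins it
```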

tests/lora/test_lora_manager.py
Lines changed: 64 additions & 0 deletions

```diff
@@ -209,6 +209,34 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     assert manager.activate_lora(3)
     assert manager.lora_index_to_id[0] == 2
     assert manager.lora_index_to_id[1] == 3
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 3
+    assert manager.activate_lora(1)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.deactivate_lora(2)
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.activate_lora(3)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.pin_lora(3)
+    assert manager.pin_lora(1)
+    with pytest.raises(RuntimeError):
+        assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 3
+    assert manager.lora_index_to_id[1] == 1
+    with pytest.raises(RuntimeError):
+        assert manager.activate_lora(2)
+
+    assert manager.deactivate_lora(3)
+    assert manager.pin_lora(2)
+    assert manager.lora_index_to_id[0] == 2
+    assert manager.lora_index_to_id[1] == 1
+    assert manager.remove_lora(3)
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(3)


 def test_lru_lora_model_manager(dist_init, dummy_model):
@@ -288,6 +316,42 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     assert set(manager.list_loras()) == set()
     assert all(x is None for x in manager.lora_index_to_id)

+    # pinning
+    assert manager.add_lora(model_lora3)
+    assert manager.activate_lora(3)
+    assert manager.add_lora(model_lora4)
+    assert manager.activate_lora(4)
+    assert set(manager.list_loras()) == {3, 4}
+    with pytest.raises(ValueError):
+        assert manager.pin_lora(1)
+    assert manager.pin_lora(3)
+    # Remove manually
+    assert manager.remove_lora(3)
+    assert not manager.remove_lora(3)
+
+    assert set(manager.list_loras()) == {4}
+    assert manager.lora_index_to_id[0] is None
+    assert manager.lora_index_to_id[1] == 4
+
+    assert manager.add_lora(model_lora1)
+    assert manager.pin_lora(1)
+    assert manager.add_lora(model_lora2)
+    assert manager.activate_lora(2)
+
+    assert set(manager.list_loras()) == {1, 2}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] == 2
+
+    assert manager.remove_oldest_lora()
+    assert set(manager.list_loras()) == {1}
+    assert manager.lora_index_to_id[0] == 1
+    assert manager.lora_index_to_id[1] is None
+
+    with pytest.raises(RuntimeError):
+        assert manager.remove_oldest_lora()
+
+    assert set(manager.list_loras()) == {1}
+

 def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):
```

vllm/engine/llm_engine.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -1009,6 +1009,9 @@ def remove_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> Set[int]:
         return self.model_executor.list_loras()

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_executor.pin_lora(lora_id)
+
     def check_health(self) -> None:
         self.model_executor.check_health()
```

vllm/executor/cpu_executor.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -84,6 +84,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
```

vllm/executor/distributed_gpu_executor.py
Lines changed: 7 additions & 0 deletions

```diff
@@ -100,6 +100,13 @@ def remove_lora(self, lora_id: int) -> bool:
             lora_id=lora_id,
         )

+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "pin_lora",
+            lora_id=lora_id,
+        )
+
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
```

vllm/executor/executor_base.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -86,6 +86,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError  # type: ignore
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
```

vllm/executor/gpu_executor.py
Lines changed: 4 additions & 0 deletions

```diff
@@ -99,6 +99,10 @@ def remove_lora(self, lora_id: int) -> bool:
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
```

vllm/executor/neuron_executor.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -65,6 +65,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.driver_worker.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
```

vllm/lora/models.py
Lines changed: 26 additions & 0 deletions

```diff
@@ -525,6 +525,12 @@ def remove_lora(self, lora_id: int) -> bool:
             self.long_lora_context.offsets_by_lora_id.pop(lora_id, None)
         return bool(self._registered_loras.pop(lora_id, None))

+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        raise NotImplementedError(
+            "Pinning is not supported in LoRAModelManager. "
+            "Use LRUCacheLoRAModelManager for pinning.")  # type: ignore
+
     # TODO see if this can be vectorized
     def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
         (base_indices, sampler_indices, sampler_indices_padded,
@@ -777,6 +783,26 @@ def remove_oldest_lora(self) -> bool:
             return True
         return False

+    def pin_lora(self, lora_id: int) -> bool:
+        """Pin a LoRAModel in the manager cache."""
+        self._pin_lora_in_cpu_cache(lora_id)
+        self._pin_lora_in_gpu_cache(lora_id)
+        return True
+
+    def _pin_lora_in_cpu_cache(self, lora_id: int):
+        try:
+            self._registered_loras.pin(lora_id)
+        except ValueError as err:
+            raise ValueError("Pinning failed. "
+                             f"LoRA {lora_id} is not registered.") from err
+
+    def _pin_lora_in_gpu_cache(self, lora_id: int):
+        if lora_id not in self._active_loras:
+            # move lora to gpu if not already active
+            self.activate_lora(lora_id)
+
+        self._active_loras.pin(lora_id)
+

 def create_lora_manager(
         model: nn.Module,
```
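
In `LRUCacheLoRAModelManager`, pinning touches both caches: the adapter is pinned in the CPU (registered) cache and, after being activated if necessary, in the GPU (active) cache. A hedged sketch of the resulting behavior, assuming `manager` is an `LRUCacheLoRAModelManager` (built via `create_lora_manager`) and `lora1`/`lora2` are `LoRAModel`s with ids 1 and 2, setup elided:

```python
manager.add_lora(lora1)       # ids 1 and 2 enter the CPU (registered) cache
manager.add_lora(lora2)
manager.pin_lora(1)           # pins id 1 on CPU; activates and pins it on GPU
manager.remove_oldest_lora()  # evicts id 2, since pinned adapters are skipped

# With only the pinned id 1 left, remove_oldest_lora() raises RuntimeError,
# and pin_lora() on an id that was never registered raises ValueError.
```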

vllm/lora/worker_manager.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -221,6 +221,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self._lora_manager.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self._lora_manager.pin_lora(lora_id)
+
     def remove_all_loras(self):
         self._lora_manager.remove_all_loras()
```

vllm/utils.py
Lines changed: 38 additions & 5 deletions

```diff
@@ -15,7 +15,7 @@
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
-                    Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
+                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                     Union)

 import numpy as np
@@ -44,6 +44,13 @@
 T = TypeVar("T")


+class _Sentinel:
+    ...
+
+
+ALL_PINNED_SENTINEL = _Sentinel()
+
+
 class Device(enum.Enum):
     GPU = enum.auto()
     CPU = enum.auto()
@@ -67,6 +74,7 @@ class LRUCache(Generic[T]):

     def __init__(self, capacity: int):
         self.cache: OrderedDict[Hashable, T] = OrderedDict()
+        self.pinned_items: Set[Hashable] = set()
         self.capacity = capacity

     def __contains__(self, key: Hashable) -> bool:
@@ -102,14 +110,36 @@ def put(self, key: Hashable, value: T) -> None:
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

+    def pin(self, key: Hashable) -> None:
+        """
+        Pins a key in the cache preventing it from being
+        evicted in the LRU order.
+        """
+        if key not in self.cache:
+            raise ValueError(f"Cannot pin key: {key} not in cache.")
+        self.pinned_items.add(key)
+
+    def _unpin(self, key: Hashable) -> None:
+        self.pinned_items.remove(key)
+
     def _on_remove(self, key: Hashable, value: Optional[T]):
         pass

-    def remove_oldest(self):
+    def remove_oldest(self, remove_pinned=False):
         if not self.cache:
             return
-        key, value = self.cache.popitem(last=False)
-        self._on_remove(key, value)
+
+        if not remove_pinned:
+            # pop the oldest item in the cache that is not pinned
+            lru_key = next(
+                (key for key in self.cache if key not in self.pinned_items),
+                ALL_PINNED_SENTINEL)
+            if lru_key is ALL_PINNED_SENTINEL:
+                raise RuntimeError("All items are pinned, "
+                                   "cannot remove oldest from the cache.")
+        else:
+            lru_key = next(iter(self.cache))
+        self.pop(lru_key)

     def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
@@ -120,13 +150,16 @@ def pop(self,
             default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
         value: Optional[T] = self.cache.pop(key, default_value)
+        # remove from pinned items
+        if key in self.pinned_items:
+            self._unpin(key)
         if run_on_remove:
             self._on_remove(key, value)
         return value

     def clear(self):
         while len(self.cache) > 0:
-            self.remove_oldest()
+            self.remove_oldest(remove_pinned=True)
         self.cache.clear()
```
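
At the cache level the new semantics are easy to exercise in isolation. A minimal sketch against the `LRUCache` above (keys and values are arbitrary):

```python
from vllm.utils import LRUCache

cache: LRUCache[str] = LRUCache(2)
cache.put(1, "a")
cache.put(2, "b")
cache.pin(1)        # key 1 is now exempt from LRU eviction
cache.put(3, "c")   # over capacity: evicts key 2, the oldest unpinned entry
assert 1 in cache and 2 not in cache and 3 in cache

cache.pin(3)
try:
    cache.remove_oldest()  # every remaining entry is pinned
except RuntimeError:
    pass

cache.clear()              # clear() removes pinned entries as well
assert 1 not in cache and 3 not in cache
```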

vllm/worker/model_runner.py
Lines changed: 5 additions & 0 deletions

```diff
@@ -878,6 +878,11 @@ def remove_lora(self, lora_id: int) -> bool:
             raise RuntimeError("LoRA is not enabled.")
         return self.lora_manager.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         if not self.lora_manager:
             raise RuntimeError("LoRA is not enabled.")
```

vllm/worker/worker.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -333,6 +333,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         return self.model_runner.remove_lora(lora_id)

+    def pin_lora(self, lora_id: int) -> bool:
+        return self.model_runner.pin_lora(lora_id)
+
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()
```

vllm/worker/worker_base.py
Lines changed: 8 additions & 0 deletions

```diff
@@ -70,6 +70,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

+    @abstractmethod
+    def pin_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
     @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
@@ -86,6 +90,10 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def remove_lora(self, lora_id: int) -> bool:
         raise ValueError(f"{type(self)} does not support LoRA")

+    def pin_lora(self, lora_id: int) -> bool:
+        raise ValueError(f"{type(self)} does not support LoRA")
+
     def list_loras(self) -> Set[int]:
         raise ValueError(f"{type(self)} does not support LoRA")
```
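
Every intermediate layer is a one-line delegation, so the change reads as a single call chain. A hedged trace of `engine.pin_lora(7)` through the files above (`engine` as in the sketch at the top, with adapter id 7 assumed already added via `add_lora`):

```python
engine.pin_lora(7)
# llm_engine.py      -> self.model_executor.pin_lora(7)
# gpu_executor.py    -> self.driver_worker.pin_lora(7)   (single-GPU case)
# worker.py          -> self.model_runner.pin_lora(7)
# model_runner.py    -> self.lora_manager.pin_lora(7)
# worker_manager.py  -> self._lora_manager.pin_lora(7)
# models.py          -> pin in the CPU cache, then activate (if needed) and
#                       pin in the GPU cache, both LRUCaches from utils.py
```

In the distributed executor the same call instead fans out to every worker via `_run_workers("pin_lora", lora_id=lora_id)`, so each worker pins the adapter in its own caches.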
