Skip to content

Commit b1a0ebb

Browse files
committed
v1: Introduce an offloading component
This commit adds a new offloading component, composed of: 1. A scheduler side OffloadingManager (abstract) which kicks-off KV data transfers and keeps track of offloaded data. 2. A worker side OffloadingWorker which asynchronously manages KV transfers. Signed-off-by: Or Ozeri <oro@il.ibm.com>
1 parent 6adaed4 commit b1a0ebb

File tree

5 files changed

+506
-0
lines changed

5 files changed

+506
-0
lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ steps:
261261
# split the test to avoid interference
262262
- pytest -v -s v1/core
263263
- pytest -v -s v1/executor
264+
- pytest -v -s v1/offloading
264265
- pytest -v -s v1/sample
265266
- pytest -v -s v1/logits_processors
266267
- pytest -v -s v1/worker

tests/v1/offloading/test_worker.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from vllm.v1.offloading.abstract import LoadStoreSpec
4+
from vllm.v1.offloading.worker.worker import (OffloadingHandler,
5+
OffloadingWorker, TransferResult,
6+
TransferSpec)
7+
8+
9+
class LoadStoreSpec1(LoadStoreSpec):
    """
    Test spec for medium "1".

    Carries flags that the test handlers in this file read to decide
    how a transfer involving this spec should behave.
    """

    def __init__(self,
                 submit_success: bool = True,
                 async_success: bool = True,
                 exception: bool = False):
        # Outcome knobs consulted by the test handlers.
        self.exception = exception
        self.submit_success = submit_success
        self.async_success = async_success
        # Set to True by the test once the transfer should be reported
        # as finished.
        self.finished = False

    @staticmethod
    def medium() -> str:
        """Return the medium identifier of this spec."""
        return "1"

    def __repr__(self):
        return "{}: {}".format(self.medium(), id(self))
26+
27+
28+
class LoadStoreSpec2(LoadStoreSpec):
    """
    Test spec for medium "2". Carries no state of its own.
    """

    @staticmethod
    def medium() -> str:
        """Return the medium identifier of this spec."""
        return "2"

    def __repr__(self):
        return "{}: {}".format(self.medium(), id(self))
36+
37+
38+
class OffloadingHandler1To2(OffloadingHandler):
    """
    Test handler for transfers from medium "1" to medium "2".

    Each accepted job is tracked by its source spec, whose flags drive
    the outcome of the transfer.
    """

    def __init__(self):
        # job_id -> source spec of every transfer still in flight
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Submit a single-block 1->2 transfer.

        Raises an Exception when the source spec has `exception` set,
        returns False when `submit_success` is unset, and otherwise
        records the job and returns True.
        """
        sources, destinations = spec
        assert len(sources) == 1
        assert len(destinations) == 1

        source, destination = sources[0], destinations[0]
        assert isinstance(source, LoadStoreSpec1)
        assert isinstance(destination, LoadStoreSpec2)

        if source.exception:
            raise Exception("An expected exception. Don't worry!")
        if source.submit_success:
            self.transfers[job_id] = source
            return True
        return False

    def get_finished(self) -> list[TransferResult]:
        """
        Collect (job_id, success) for every tracked job whose spec is
        marked finished, and stop tracking those jobs.
        """
        done = [(job_id, tracked.async_success)
                for job_id, tracked in self.transfers.items()
                if tracked.finished]
        # Removal happens after the scan so the dict is never mutated
        # while being iterated.
        for job_id, _ in done:
            del self.transfers[job_id]
        return done
67+
68+
69+
class OffloadingHandler2To1(OffloadingHandler):
    """
    Test handler for transfers from medium "2" to medium "1".

    Each job is tracked by its destination spec, which is the side that
    carries the completion flags.
    """

    def __init__(self):
        # job_id -> destination spec of every transfer still in flight
        self.transfers: dict[int, LoadStoreSpec1] = {}

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Submit a single-block 2->1 transfer. Always records the job and
        returns True.
        """
        sources, destinations = spec
        assert len(sources) == 1
        assert len(destinations) == 1

        source, destination = sources[0], destinations[0]
        assert isinstance(source, LoadStoreSpec2)
        assert isinstance(destination, LoadStoreSpec1)

        self.transfers[job_id] = destination
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Collect (job_id, success) for every tracked job whose spec is
        marked finished, and stop tracking those jobs.
        """
        done = [(job_id, tracked.async_success)
                for job_id, tracked in self.transfers.items()
                if tracked.finished]
        # Removal happens after the scan so the dict is never mutated
        # while being iterated.
        for job_id, _ in done:
            del self.transfers[job_id]
        return done
93+
94+
95+
def test_offloading_worker():
    """
    Tests OffloadingWorker with 2 handlers.
    One handler performs 1->2 transfers, and the other handles 2->1.

    Covers the submission outcomes (handler raises, handler rejects,
    handler accepts) as well as the asynchronous completion results
    reported through get_finished().
    """
    worker = OffloadingWorker()
    handler1to2 = OffloadingHandler1To2()
    handler2to1 = OffloadingHandler2To1()
    worker.register_handler(LoadStoreSpec1, LoadStoreSpec2, handler1to2)
    worker.register_handler(LoadStoreSpec2, LoadStoreSpec1, handler2to1)

    # 1st transfer 1->2 (exception)
    # the handler raises; the worker must report this as a failed submit
    src1 = LoadStoreSpec1(exception=True)
    dst1 = LoadStoreSpec2()
    assert not worker.transfer_async(1, ([src1], [dst1]))

    # 2nd transfer 1->2 (failure to submit)
    src2 = LoadStoreSpec1(submit_success=False)
    dst2 = LoadStoreSpec2()
    assert not worker.transfer_async(2, ([src2], [dst2]))

    # 3rd transfer 1->2 (failure)
    # submission succeeds; the failure only shows up on completion
    src3 = LoadStoreSpec1(async_success=False)
    dst3 = LoadStoreSpec2()
    assert worker.transfer_async(3, ([src3], [dst3]))

    # 4th transfer 1->2 (success)
    src4 = LoadStoreSpec1()
    dst4 = LoadStoreSpec2()
    assert worker.transfer_async(4, ([src4], [dst4]))
    # only the successfully submitted jobs are tracked by the handler
    assert set(handler1to2.transfers.keys()) == {3, 4}

    # 5th transfer 2->1
    src5 = LoadStoreSpec2()
    dst5 = LoadStoreSpec1()
    assert worker.transfer_async(5, ([src5], [dst5]))
    assert set(handler2to1.transfers.keys()) == {5}

    # no transfer completed yet
    assert worker.get_finished() == []

    # complete 3rd, 4th
    src3.finished = True
    src4.finished = True

    # 6th transfer 1->2
    src6 = LoadStoreSpec1()
    dst6 = LoadStoreSpec2()
    assert worker.transfer_async(6, ([src6], [dst6]))

    # 7th transfer 2->1
    src7 = LoadStoreSpec2()
    dst7 = LoadStoreSpec1()
    assert worker.transfer_async(7, ([src7], [dst7]))

    # 6th and 7th transfers started
    assert 6 in handler1to2.transfers
    assert 7 in handler2to1.transfers

    # verify result of 3rd and 4th transfers
    # (sorted, since the order across handlers is not asserted here)
    assert sorted(worker.get_finished()) == [(3, False), (4, True)]

    # complete 6th and 7th transfers
    # the 2->1 handler tracks the destination spec, hence dst7
    src6.finished = True
    dst7.finished = True
    assert sorted(worker.get_finished()) == [(6, True), (7, True)]

vllm/v1/offloading/abstract.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
"""
4+
OffloadingManager class for managing KV data offloading in vLLM v1
5+
6+
This class runs in the scheduler, tracks which blocks are offloaded
7+
and their address.
8+
9+
The class provides the following primitives:
10+
lookup() - find the length of the maximal series of blocks,
11+
starting from the first one, that are all offloaded.
12+
prepare_load() - prepare given blocks to be read.
13+
The given blocks will be protected from eviction.
14+
This function returns a LoadSpec which encapsulates
15+
information required for performing the load.
16+
touch() - marks the given blocks as recently used. Can be used
17+
to track block's LRU. This function is separated from the
18+
prepare_load function to allow setting block recency even
19+
for blocks which do not need reading from the cache, such as
20+
blocks that are cached by the GPU prefix cache.
21+
complete_load() - mark blocks which were previously prepared to be
22+
loaded as done loading. This is to re-allow their eviction.
23+
prepare_store() - prepare the given blocks to be written.
24+
Returns a StoreSpec encapsulating offloading information,
25+
as well as a list of blocks that were evicted as a result.
26+
complete_store() - marks a previous store as completed.
27+
Following this call, the given blocks will become loadable.
28+
"""
29+
30+
from abc import ABC, abstractmethod
31+
from collections.abc import Iterable
32+
from dataclasses import dataclass
33+
from typing import Optional
34+
35+
36+
class LoadStoreSpec(ABC):
    """
    Base class for metadata that gives a worker the information it
    needs to load a block of KV data, and optionally to store one.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Name the medium type targeted by this load/store spec.

        Returns:
            A string representation of the medium type.
        """
        ...
50+
51+
52+
@dataclass
class PrepareStoreOutput:
    """
    Result of OffloadingManager.prepare_store().
    """

    # hashes of the blocks that need storing
    block_hashes_to_store: list[int]
    # where to store the blocks
    # NOTE(review): presumably one spec per entry of
    # block_hashes_to_store, in the same order - confirm
    store_specs: list[LoadStoreSpec]
    # hashes of the blocks that were evicted as a result
    block_hashes_evicted: list[int]
57+
58+
59+
@dataclass
class OffloadingEvent:
    """
    Record of blocks being stored to, or removed from, a medium.
    """

    # hashes of the blocks this event refers to
    block_hashes: list[int]
    # NOTE(review): unit of block_size is not visible here - confirm
    block_size: int
    # NOTE(review): presumably a LoadStoreSpec.medium() value - confirm
    medium: str
    # True if blocks are removed, False if stored
    removed: bool
66+
67+
68+
class OffloadingManager(ABC):
    """
    Abstract interface for tracking which KV blocks are offloaded.

    Subclasses must implement lookup(), prepare_load() and
    prepare_store(). The remaining methods are optional hooks whose
    default implementation does nothing.
    """

    @abstractmethod
    def lookup(self, block_hashes: list[int]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to lookup.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: list[int]) -> list[LoadStoreSpec]:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A list of LoadStoreSpec, one per each block, that can be used by
            a worker to locate and load the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: list[int]) -> None:
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: list[int]) -> None:
        """
        Marks previous blocks that were prepared to load as done loading.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(self,
                      block_hashes: list[int]) -> Optional[PrepareStoreOutput]:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self,
                       block_hashes: list[int],
                       success: bool = True) -> None:
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed.

        The default implementation does nothing.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.

        The default implementation reports no events.

        Returns:
            An iterable of the new OffloadingEvents collected since the
            last call.
        """
        return ()

vllm/v1/offloading/mediums.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from abc import ABC
4+
5+
from vllm.v1.offloading.abstract import LoadStoreSpec
6+
7+
8+
class BlockIDLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Load/store spec that addresses a KV block by its block number.
    """

    def __init__(self, block_id: int):
        # the block number this spec refers to
        self.block_id = block_id

    def __repr__(self) -> str:
        return f"{self.block_id!s}"
18+
19+
20+
class GPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Block-number spec for loading/storing a KV block to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        """Return the medium identifier, "GPU"."""
        return "GPU"
28+
29+
30+
class CPULoadStoreSpec(BlockIDLoadStoreSpec):
    """
    Block-number spec for loading/storing a KV block to CPU memory.
    """

    @staticmethod
    def medium() -> str:
        """Return the medium identifier, "CPU"."""
        return "CPU"

0 commit comments

Comments
 (0)