Skip to content

Commit d46dd92

Browse files
LiuXiaoxuanPKU authored and erdaltoprak committed
[V1][Spec Decode] KV cache slots for eagle heads (vllm-project#16370)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Signed-off-by: Erdal Toprak <contact@erdaltoprak.com>
1 parent 6152c6c commit d46dd92

File tree

4 files changed

+98
-18
lines changed

4 files changed

+98
-18
lines changed

tests/v1/core/test_kv_cache_utils.py

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
88
from vllm.sampling_params import SamplingParams
99
from vllm.utils import GiB_bytes, sha256
10+
from vllm.v1.core.kv_cache_manager import KVCacheManager
1011
# disable yapf here as it formats differently than isort such that both fail
1112
# yapf: disable
1213
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
@@ -48,6 +49,18 @@ def make_request(request_id,
4849
)
4950

5051

52+
def new_kv_cache_spec(block_size=16,
                      num_kv_heads=2,
                      head_size=64,
                      dtype=torch.float32,
                      use_mla=False):
    """Build a FullAttentionSpec with test-friendly defaults.

    Every field can be overridden per test; the defaults describe a small
    fp32 attention layout (2 KV heads of size 64, 16-token blocks, no MLA).
    """
    spec_kwargs = dict(
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        use_mla=use_mla,
    )
    return FullAttentionSpec(**spec_kwargs)
62+
63+
5164
def test_none_hash():
5265
assert NONE_HASH is not None
5366
assert isinstance(NONE_HASH, int)
@@ -327,18 +340,6 @@ def stats(requests, queries, hits):
327340

328341

329342
def test_unify_kv_cache_configs():
330-
331-
def new_kv_cache_spec(block_size=16,
332-
num_kv_heads=2,
333-
head_size=64,
334-
dtype=torch.float32,
335-
use_mla=False):
336-
return FullAttentionSpec(block_size=block_size,
337-
num_kv_heads=num_kv_heads,
338-
head_size=head_size,
339-
dtype=dtype,
340-
use_mla=use_mla)
341-
342343
same_kv_cache_config = [
343344
KVCacheConfig(
344345
num_blocks=10,
@@ -470,3 +471,64 @@ def test_estimate_max_model_len(model_id, max_model_len,
470471
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
471472
8 * GiB_bytes)
472473
assert estimated_max_len == want_estimated_max_len
474+
475+
476+
def test_allocate_with_lookahead():
    """Verify that lookahead tokens correctly affect block allocation."""
    block_size = 4
    # Single-layer KV cache: 10 blocks of `block_size` tokens each.
    config = KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer1": KVCacheTensor(100),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(["layer1"],
                             new_kv_cache_spec(block_size=block_size)),
        ],
    )

    request = make_request(
        request_id=0,
        prompt_token_ids=[],
        mm_positions=None,
        mm_hashes=None,
    )

    def allocate(num_preallocate_tokens, num_lookahead_tokens):
        # Fresh manager per case so earlier allocations don't carry over.
        manager = KVCacheManager(kv_cache_config=config,
                                 max_model_len=100,
                                 num_preallocate_tokens=num_preallocate_tokens)
        return manager.allocate_slots(
            request,
            num_tokens=3,
            num_lookahead_tokens=num_lookahead_tokens,
        )

    # Case 1: requires additional lookahead tokens, no preallocation.
    # Total required: 3 + 2 = 5 tokens -> ceil(5/4) = 2 blocks.
    assert len(allocate(0, 2)) == 2

    # Case 2: with preallocated tokens.
    # num_preallocate_blocks = 4 // 4 - 2 // 4 = 1
    # required_blocks = ceil((3 + 2) / 4) = 2
    # total_blocks = 1 + 2 = 3
    assert len(allocate(4, 2)) == 3

    # Case 3: lookahead fully consumes the preallocation budget.
    # num_preallocate_blocks = 4 // 4 - 4 // 4 = 0
    # required_blocks = ceil((3 + 4) / 4) = 2
    # total_blocks = 0 + 2 = 2
    assert len(allocate(4, 4)) == 2

vllm/v1/core/kv_cache_manager.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@ def allocate_slots(
164164
self,
165165
request: Request,
166166
num_tokens: int,
167-
new_computed_blocks: Optional[list[KVCacheBlock]] = None
167+
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
168+
num_lookahead_tokens: int = 0,
168169
) -> Optional[list[KVCacheBlock]]:
169170
"""Add slots for a request with new tokens to append.
170171
@@ -174,6 +175,9 @@ def allocate_slots(
174175
not include the tokens that have already been computed.
175176
new_computed_blocks: A list of new computed blocks just hitting the
176177
prefix caching.
178+
num_lookahead_tokens: The number of speculative tokens to allocate.
179+
This is used by spec decode proposers with kv-cache such
180+
as eagle.
177181
178182
Blocks layout:
179183
-----------------------------------------------------------------------
@@ -211,8 +215,9 @@ def allocate_slots(
211215
# the new prefix caching hits
212216
num_computed_tokens = (request.num_computed_tokens +
213217
len(new_computed_blocks) * self.block_size)
214-
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
215-
self.block_size)
218+
num_required_blocks = cdiv(
219+
num_computed_tokens + num_tokens + num_lookahead_tokens,
220+
self.block_size)
216221
num_new_blocks = (num_required_blocks - len(req_blocks) -
217222
len(new_computed_blocks))
218223

@@ -246,8 +251,11 @@ def allocate_slots(
246251
else:
247252
# Get new blocks from the free block pool considering
248253
# preallocated blocks.
254+
num_preallocate_blocks = max(
255+
0, self.num_preallocate_blocks -
256+
num_lookahead_tokens // self.block_size)
249257
num_new_blocks = min(
250-
num_new_blocks + self.num_preallocate_blocks,
258+
num_new_blocks + num_preallocate_blocks,
251259
self.block_pool.get_num_free_blocks(),
252260
# Should not exceed the maximum number of blocks per request.
253261
# This is especially because the block table has the shape

vllm/v1/core/sched/scheduler.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from collections.abc import Iterable
88
from typing import Optional, Union
99

10-
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
10+
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
11+
SpeculativeConfig)
1112
from vllm.logger import init_logger
1213
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
1314
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
@@ -39,6 +40,7 @@ def __init__(
3940
lora_config: Optional[LoRAConfig],
4041
kv_cache_config: KVCacheConfig,
4142
structured_output_manager: StructuredOutputManager,
43+
speculative_config: SpeculativeConfig = None,
4244
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
4345
include_finished_set: bool = False,
4446
log_stats: bool = False,
@@ -112,6 +114,11 @@ def __init__(
112114
self.encoder_cache_manager = EncoderCacheManager(
113115
cache_size=encoder_cache_size)
114116

117+
self.num_lookahead_tokens = 0
118+
if speculative_config and speculative_config.method == "eagle":
119+
self.num_lookahead_tokens = \
120+
speculative_config.num_speculative_tokens
121+
115122
def schedule(self) -> SchedulerOutput:
116123
# NOTE(woosuk) on the scheduling algorithm:
117124
# There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -188,7 +195,9 @@ def schedule(self) -> SchedulerOutput:
188195

189196
while True:
190197
new_blocks = self.kv_cache_manager.allocate_slots(
191-
request, num_new_tokens)
198+
request,
199+
num_new_tokens,
200+
num_lookahead_tokens=self.num_lookahead_tokens)
192201
if new_blocks is None:
193202
# The request cannot be scheduled.
194203
# Preempt the lowest-priority request.

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
cache_config=vllm_config.cache_config,
9999
lora_config=vllm_config.lora_config,
100100
kv_cache_config=kv_cache_config,
101+
speculative_config=vllm_config.speculative_config,
101102
structured_output_manager=self.structured_output_manager,
102103
include_finished_set=vllm_config.parallel_config.data_parallel_size
103104
> 1,

0 commit comments

Comments
 (0)