Commit 93deb0b

[Speculative decoding 4/9] Lookahead scheduling for speculative decoding (#3250)
1 parent ccb58b2 commit 93deb0b

13 files changed, +579 -123 lines changed

tests/core/block/e2e/test_correctness.py

Lines changed: 153 additions & 0 deletions
@@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Use a large block size to trigger more copy-on-writes.
        "block_size": 32,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
                                        test_llm_generator, batch_size):
    """Verify beam search equality with block manager v1 and v2.

    This requires copy-on-writes; if the v1 and v2 output is the same, then
    we have some confidence cow is working.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        use_beam_search=True,
        best_of=2,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # Our prompts will generate 128 tokens; since the prompts themselves are
        # small, we don't need much KV space beyond 128.
        "max_model_len": 160,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Lookahead scheduling only supported in v2 block manager.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "block_size": 16,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 8 = 128/block_size
            "forced_num_gpu_blocks": 2 * (8 + 1),
        },
        {
            "block_size": 8,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 16 = 128/block_size
            "forced_num_gpu_blocks": 2 * (16 + 1),
        }
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "num_lookahead_slots": 0,
}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [{
        # We run one test with block_size < lookahead_slots, one test with
        # block_size > lookahead_slots
        "num_lookahead_slots": 10,
    }])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
                                                   test_llm_generator,
                                                   batch_size):
    """Verify vLLM produces the same output with greedy sampling, when lookahead
    scheduling is used vs. not.

    Lookahead scheduling is not expected to modify the output, as it simply
    allocates empty slots ahead of the known token ids in a sliding fashion.

    This test constrains the total number of blocks to force preemption. It also
    varies the block size so that the lookahead size is less than and greater
    than the block size.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids without lookahead scheduling')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with lookahead scheduling')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    for llm in llm_generator:
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
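
A minimal sketch (outside the diff) of the ceil-division arithmetic behind the forced_num_gpu_blocks comments and the lookahead-slot reservation described in the docstring above; blocks_needed is a hypothetical helper, not a vLLM API:

import math

def blocks_needed(prompt_len: int, output_len: int,
                  num_lookahead_slots: int, block_size: int) -> int:
    # Slots the sequence may occupy: its known tokens plus the empty
    # lookahead slots reserved ahead of them, rounded up to whole KV blocks.
    return math.ceil(
        (prompt_len + output_len + num_lookahead_slots) / block_size)

# With block_size=16 and ~128 generated tokens, a sequence needs roughly
# 128/16 + 1 = 9 blocks, which is why forced_num_gpu_blocks = 2 * (8 + 1)
# admits only two such sequences at a time.
assert blocks_needed(16, 128, 0, block_size=16) == 9
# Reserving 10 lookahead slots can push the same sequence into one extra
# block, increasing memory pressure and forcing preemption sooner.
assert blocks_needed(16, 128, 10, block_size=16) == 10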
tests/core/block/test_block_manager_v2.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import pytest

from vllm.core.block_manager_v2 import BlockSpaceManagerV2
from vllm.core.interfaces import AllocStatus
from vllm.sequence import Logprob, SequenceStatus
from vllm.utils import chunk_list

from ..utils import create_seq_group


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    num_watermark_blocks = int(watermark * num_gpu_blocks)

    num_output_blocks_per_seq = 1

    # NOTE: This should be num_output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    num_output_blocks = num_output_blocks_per_seq

    for num_prompt_blocks in range(1, num_gpu_blocks - num_output_blocks):
        seq_group = create_seq_group(
            seq_prompt_len=block_size * num_prompt_blocks,
            seq_output_lens=[
                block_size * num_output_blocks_per_seq
                for _ in range(num_seqs_per_group)
            ],
        )

        assert num_prompt_blocks + num_output_blocks <= num_gpu_blocks

        can_allocate_result = block_manager.can_allocate(seq_group)

        num_required_blocks = num_prompt_blocks + num_output_blocks

        if num_gpu_blocks - num_required_blocks < num_watermark_blocks:
            assert can_allocate_result == AllocStatus.NEVER
        elif num_gpu_blocks >= num_required_blocks:
            assert can_allocate_result == AllocStatus.OK
        else:
            assert can_allocate_result == AllocStatus.LATER


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("prompt_len", [1, 7, 8])
@pytest.mark.parametrize("num_slots_to_append", [1, 8, 129])
@pytest.mark.parametrize("num_lookahead_slots", [0, 10])
def test_append_slots(block_size, prompt_len, num_slots_to_append,
                      num_lookahead_slots):
    """Verify append_slots consumes the correct number of blocks from the block
    table.
    """

    num_gpu_blocks = 1024
    watermark = 0.1
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        watermark=watermark,
    )

    seq_group = create_seq_group(
        seq_prompt_len=prompt_len,
        seq_output_lens=[0],
    )

    # Allocate seq
    assert block_manager.can_allocate(seq_group)
    block_manager.allocate(seq_group)

    # Set seq to RUNNING
    seq = seq_group.get_seqs()[0]
    seq.status = SequenceStatus.RUNNING

    # Append tokens to the sequence
    for token_id in range(num_slots_to_append):
        seq.append_token_id(token_id, {token_id: Logprob(0.0)})

    # Append slots for new tokens and lookahead slots.
    free_blocks_before_append = block_manager.get_num_free_gpu_blocks()
    block_manager.append_slots(seq, num_lookahead_slots)
    num_consumed_blocks = (free_blocks_before_append -
                           block_manager.get_num_free_gpu_blocks())

    # Expect consumed blocks to be new blocks required to support the new slots.
    expected_consumed_blocks = len(
        chunk_list(
            list(
                range(prompt_len + num_slots_to_append + num_lookahead_slots)),
            block_size)) - len(chunk_list(list(range(prompt_len)), block_size))
    assert num_consumed_blocks == expected_consumed_blocks

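test_append_slots above derives the expected block consumption via chunk_list; the same quantity can be stated as a ceil-division difference. A minimal sketch (outside the diff), using a hypothetical expected_new_blocks helper rather than any vLLM API:

import math

def expected_new_blocks(prompt_len: int, num_appended: int,
                        num_lookahead_slots: int, block_size: int) -> int:
    # Blocks occupied once the appended tokens and lookahead slots are in
    # place, minus blocks already occupied by the prompt alone. This mirrors
    # the chunk_list-based computation in the test.
    blocks_after = math.ceil(
        (prompt_len + num_appended + num_lookahead_slots) / block_size)
    blocks_before = math.ceil(prompt_len / block_size)
    return blocks_after - blocks_before

# Example: a 7-token prompt in 8-token blocks occupies 1 block; appending
# 8 tokens plus 10 lookahead slots needs ceil(25 / 8) = 4 blocks in total,
# so append_slots should consume 3 additional blocks.
assert expected_new_blocks(7, 8, 10, block_size=8) == 3
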
tests/core/block/test_block_space_manager.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

tests/core/block/test_block_table.py

Lines changed: 75 additions & 0 deletions
@@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,

    # After free, expect all blocks to be freed.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks


@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_new_tokens", [1, 16, 129])
@pytest.mark.parametrize("num_lookahead_slots", [1, 7, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_num_blocks_touched_by_append_slots(block_size: int, sequence_len: int,
                                            num_new_tokens: int,
                                            num_lookahead_slots: int,
                                            allocator_type: str):
    """Verify correct calculation of get_num_blocks_touched_by_append_slots.

    This is done by using copy-on-write, which requires any modified block to
    be copied before write if the refcount > 1. We set the refcount>1 by forking
    a sequence, then measure the free blocks before and after an append. If the
    number of consumed blocks equals what `get_num_blocks_touched_by_append_
    slots` returns, then the calculation is correct.
    """

    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_ids_to_append = list(range(num_new_tokens))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    block_table.allocate(token_ids=token_ids, device=Device.GPU)

    # Add lookahead before fork so both sequences have the same lookahead
    # blocks.
    block_table.ensure_num_empty_slots(num_empty_slots=num_lookahead_slots)

    # Fork sequence so that every block has refcount > 1.
    _ = block_table.fork()

    # Determine how many blocks should be touched.
    expected_num_touched_blocks = (
        block_table.get_num_blocks_touched_by_append_slots(
            token_ids=token_ids_to_append,
            num_lookahead_slots=num_lookahead_slots))

    # Measure how many blocks are touched by measuring num_free_blocks before
    # and after the append.
    #
    # We expect append_token_ids to CoW all mutated blocks that have refcount>1.
    num_free_blocks_before_append = allocator.get_num_free_blocks(Device.GPU)
    block_table.append_token_ids(token_ids_to_append, num_lookahead_slots)
    num_consumed_blocks = (num_free_blocks_before_append -
                           allocator.get_num_free_blocks(Device.GPU))

    # TODO(cade) ensure equality when num_lookahead_slots > 0.
    # The reason we have < is because lookahead blocks are not copied eagerly;
    # they are copied on first write. This will cause issues for beam search +
    # speculative decoding. This is acceptable for now as it is a large effort
    # to combine the two. To fix this, we can ensure single sequence ownership
    # of lookahead blocks by appending empty slots to each block, which will
    # trigger the CoW.
    #
    # Until then, we can accept that the consumed tokens are <= the expected
    # tokens when appending with lookahead.
    if num_lookahead_slots > 0:
        assert num_consumed_blocks <= expected_num_touched_blocks
    else:
        assert num_consumed_blocks == expected_num_touched_blocks

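The docstring of test_num_blocks_touched_by_append_slots leans on copy-on-write semantics: after a fork every block is shared, so writing into a shared block first costs one free block for the private copy. A toy refcount model (outside the diff, not vLLM's allocator) illustrating why consumed blocks should match the predicted number of touched blocks:

from typing import Dict, List

def simulate_append_with_cow(refcounts: Dict[int, int],
                             blocks_to_write: List[int],
                             num_new_blocks: int) -> int:
    """Toy model: free blocks consumed when shared blocks are written."""
    consumed = 0
    for block_id in blocks_to_write:
        if refcounts[block_id] > 1:
            # Copy-on-write: the writer gets a private copy, costing one
            # free block; the original keeps serving the forked sequence.
            consumed += 1
            refcounts[block_id] -= 1
    # Tokens that overflow the last block require brand-new blocks.
    return consumed + num_new_blocks

# After a fork every block is shared (refcount == 2). Writing into the last
# partially-filled block and spilling into two fresh blocks consumes
# 1 (CoW copy) + 2 (new) = 3 blocks, which is what
# get_num_blocks_touched_by_append_slots is expected to predict.
assert simulate_append_with_cow({0: 2, 1: 2}, blocks_to_write=[1],
                                num_new_blocks=2) == 3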