Skip to content

Commit 0c63c34

Browse files
[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (#9730)
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
1 parent 966e316 commit 0c63c34

File tree

4 files changed

+159
-10
lines changed

4 files changed

+159
-10
lines changed

tests/spec_decode/test_multi_step_worker.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import pytest
66
import torch
77

8+
from vllm.attention.selector import (_Backend,
9+
global_force_attn_backend_context_manager)
810
from vllm.model_executor.layers.sampler import SamplerOutput
911
from vllm.model_executor.utils import set_random_seed
1012
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
303305
seed,
304306
model_runner_cls=TP1DraftModelRunner,
305307
)
308+
multi_step_worker.set_include_gpu_probs_tensor()
306309
worker = create_worker(
307310
Worker,
308311
model_name,
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
397400
seed,
398401
model_runner_cls=TP1DraftModelRunner,
399402
)
403+
multi_step_worker.set_include_gpu_probs_tensor()
400404
worker = create_worker(
401405
Worker,
402406
model_name,
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
477481
assert (num_mismatch > 0)
478482

479483

484+
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# Parametrizing over the attention backend forces the multi_step_worker to
# pick either the vanilla model_runner or TP1DraftModelRunner, so that both
# code paths are tested.
@pytest.mark.parametrize('attn_backend',
                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
    """Check the draft model's KV cache for sequences with a bonus token.

    A MultiStepWorker runs ``num_steps`` proposal steps in one
    ``sampler_output`` call with every sequence flagged as having received
    a bonus token in the last step. A plain Worker is then driven over the
    same tokens one step at a time, starting with a pass that only sees the
    first of the two pre-generated tokens. Afterwards, both entries of every
    layer of the two GPU caches must agree within rtol = atol = 1e-2.
    """
    seed = 100
    model_name = "JackFram/llama-68m"

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    batch_size = 1

    with global_force_attn_backend_context_manager(attn_backend):
        # NOTE(review): float16 is presumably chosen because FLASH_ATTN
        # does not run in float32 -- confirm.
        if attn_backend == _Backend.FLASH_ATTN:
            dtype = 'float16'
        else:
            dtype = 'float32'

        multi_step_worker = create_worker(
            MultiStepWorker,
            model_name,
            block_size,
            num_gpu_blocks,
            seed,
            model_runner_cls=TP1DraftModelRunner,
            dtype=dtype,
        )
        multi_step_worker.set_include_gpu_probs_tensor()

        worker = create_worker(
            Worker,
            model_name,
            block_size,
            num_gpu_blocks,
            seed,
            dtype=dtype,
        )

        prompts = [[0] for _ in range(batch_size)]
        # Pretend two tokens were already generated for each sequence so
        # that the bonus-token situation can be simulated.
        multi_step_continuations = [[
            random.randint(0, 1000) for _ in range(2)
        ] for _ in prompts]
        final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]

        def build_metadata(continuations):
            # Every request in this test shares the prompts, block settings
            # and final lengths; only the continuations differ.
            return create_seq_group_metadata_from_prompts(
                prompts,
                num_gpu_blocks,
                block_size,
                continuations=continuations,
                final_prompt_lens=final_prompt_lens)

        seq_ids_with_bonus_token_in_last_step = set(range(batch_size))

        # Multi-step run: all proposal steps in a single call.
        zero_kv_cache(multi_step_worker.cache_engine)
        multi_step_request = ExecuteModelRequest(
            seq_group_metadata_list=build_metadata(multi_step_continuations))
        multi_step_worker.sampler_output(
            execute_model_req=multi_step_request,
            sample_len=num_steps,
            seq_ids_with_bonus_token_in_last_step=
            seq_ids_with_bonus_token_in_last_step)

        # Single-step reference run.
        zero_kv_cache(worker.cache_engine)
        # First produce the KV cache entry for the bonus token by running
        # with only the first pre-generated token.
        first_token_only = [
            tokens[:1] for tokens in multi_step_continuations
        ]
        worker.execute_model(execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=build_metadata(first_token_only)))

        for _ in range(num_steps):
            step_request = ExecuteModelRequest(
                seq_group_metadata_list=build_metadata(
                    multi_step_continuations))
            single_step_output = worker.execute_model(
                execute_model_req=step_request)

            # Feed each sampled token back so the next step sees it.
            for seq_idx, group_output in enumerate(single_step_output[-1]):
                sampled_token = group_output.samples[0].output_token
                multi_step_continuations[seq_idx].append(sampled_token)

        # The single-step and multi-step workers must end up with the same
        # KV cache contents.
        single_step_gpu_cache = worker.cache_engine[0].gpu_cache
        multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache

        def caches_match(lhs, rhs):
            return torch.allclose(lhs.cuda(),
                                  rhs.cuda(),
                                  rtol=1e-2,
                                  atol=1e-2)

        for layer_idx, single_layer in enumerate(single_step_gpu_cache):
            multi_layer = multi_step_gpu_cache[layer_idx]
            assert caches_match(single_layer[0], multi_layer[0])
            assert caches_match(single_layer[1], multi_layer[1])
586+
480587
@torch.inference_mode()
481588
def test_draft_proposals_full_speculation_len():
482589
"""Verify Top1Proposer correctly handles case where all sequences

tests/spec_decode/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,14 @@ def create_worker(cls: Callable[..., T],
6868
seed: int,
6969
is_driver_worker: bool = True,
7070
enforce_eager: bool = True,
71-
model_runner_cls: Optional[ModelRunner] = None) -> T:
71+
model_runner_cls: Optional[ModelRunner] = None,
72+
dtype: Optional[str] = "auto") -> T:
7273
engine_args = EngineArgs(
7374
model=model_name,
7475
seed=seed,
7576
block_size=block_size,
7677
enforce_eager=enforce_eager,
78+
dtype=dtype,
7779
)
7880
engine_config = engine_args.create_engine_config()
7981

vllm/spec_decode/draft_model_runner.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def __init__(self, *args, **kwargs):
5454

5555
super().__init__(*args, **kwargs)
5656

57+
self.indices_of_seq_with_bonus_tokens = None
58+
5759
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
5860
num_queries):
5961

@@ -159,6 +161,10 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
159161
# TODO: Add soft-tuning prompt adapter support
160162
return not self.prompt_adapter_config
161163

164+
def set_indices_of_seq_with_bonus_tokens(
        self, indices_of_seq_with_bonus_tokens):
    """Remember which batch positions hold sequences with a bonus token.

    The value is stored unchanged on the runner and read by
    ``execute_model`` after sampling in steps without prefills. Passing
    ``None`` turns that post-sampling handling off.

    Args:
        indices_of_seq_with_bonus_tokens: Indices into the batch, or
            ``None``.
    """
    self.indices_of_seq_with_bonus_tokens = (
        indices_of_seq_with_bonus_tokens)
167+
162168
@torch.inference_mode()
163169
def execute_model(
164170
self,
@@ -284,11 +290,30 @@ def execute_model(
284290
model_input.sampling_metadata)
285291

286292
# Sample the next token.
287-
outputs.append(
288-
self.model.sample(
289-
logits=logits,
290-
sampling_metadata=model_input.sampling_metadata,
291-
))
293+
output = self.model.sample(
294+
logits=logits,
295+
sampling_metadata=model_input.sampling_metadata,
296+
)
297+
outputs.append(output)
298+
299+
if model_input.attn_metadata.num_prefills == 0 \
300+
and self.indices_of_seq_with_bonus_tokens is not None:
301+
assert output.sampled_token_ids is not None
302+
# output.sampled_token_ids should be of shape (num_seqs, 1)
303+
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
304+
assert num_tokens_per_seq == 1
305+
count = 0
306+
for i in range(nums_seqs):
307+
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
308+
count]
309+
if i != bonus_seq_idx:
310+
# The following might cause a CPU->GPU sync.
311+
# However, the performance impact is negligible as we
312+
# benchmarked on H100.
313+
output.sampled_token_ids[
314+
i, :] = model_input.input_tokens[bonus_seq_idx]
315+
else:
316+
count += 1
292317

293318
# Prepare inputs for the next step
294319
if step != num_steps - 1:

vllm/spec_decode/multi_step_worker.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ def sampler_output(
8181
# Here we run the draft_model_runner with multi-step prepare
8282
# on the GPU directly
8383
expanded_request.num_steps = sample_len
84+
self.model_runner.set_indices_of_seq_with_bonus_tokens(
85+
indices_of_seq_with_bonus_tokens)
8486
model_outputs = self.execute_model(
8587
execute_model_req=expanded_request)
8688
else:
@@ -97,7 +99,8 @@ def sampler_output(
9799
model_output = model_output[0]
98100

99101
self._append_new_tokens(
100-
model_output, expanded_request.seq_group_metadata_list)
102+
model_output, expanded_request.seq_group_metadata_list,
103+
indices_of_seq_with_bonus_tokens)
101104
model_outputs.append(model_output)
102105

103106
filtered_model_outputs = self._filter_model_output(
@@ -221,13 +224,15 @@ def get_spec_proposals(
221224
@staticmethod
222225
def _append_new_tokens(
223226
model_output: List[SamplerOutput],
224-
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
227+
seq_group_metadata_list: List[SequenceGroupMetadata],
228+
indices_of_seq_with_bonus_tokens: List[int]) -> None:
225229
"""Given model output from a single run, append the tokens to the
226230
sequences. This is normally done outside of the worker, but it is
227231
required if the worker is to perform multiple forward passes.
228232
"""
229-
for seq_group_metadata, sequence_group_outputs in zip(
230-
seq_group_metadata_list, model_output):
233+
count = 0
234+
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
235+
zip(seq_group_metadata_list, model_output)):
231236
seq_group_metadata.is_prompt = False
232237

233238
for seq_output in sequence_group_outputs.samples:
@@ -237,6 +242,16 @@ def _append_new_tokens(
237242

238243
token_id = seq_output.output_token
239244
token_logprob = seq_output.logprobs[token_id]
245+
# Determine the actual token ID to be generated,
246+
# considering bonus tokens
247+
if index != indices_of_seq_with_bonus_tokens[count]:
248+
bonus_seq_metadata = seq_group_metadata_list[
249+
indices_of_seq_with_bonus_tokens[count]]
250+
_, bonus_token_seq_data = next(
251+
iter(bonus_seq_metadata.seq_data.items()))
252+
token_id = bonus_token_seq_data.output_token_ids[-1]
253+
else:
254+
count += 1
240255

241256
seq.append_token_id(token_id, token_logprob.logprob)
242257
seq.update_num_computed_tokens(1)

0 commit comments

Comments
 (0)