Commit 1009e93

[Encoder decoder] Add cuda graph support during decoding for encoder-decoder models (#7631)
1 parent 1b6de83 commit 1009e93
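
With this change, decode steps for encoder-decoder models can run under CUDA graph capture and replay instead of requiring eager execution. For context only (not part of the diff), a minimal sketch of exercising the new path from the offline API; the model and dtype come from the e2e test added below, while the prompt and sampling settings are placeholder assumptions:

# Sketch only: CUDA graphs are used when enforce_eager is left False.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn",
          dtype="bfloat16",
          enforce_eager=False)  # eager mode no longer required for enc-dec
outputs = llm.generate(
    "An example encoder prompt for the BART summarization model.",
    SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)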

File tree

15 files changed: +526, -112 lines


.buildkite/test-pipeline.yaml (+7)

@@ -252,6 +252,13 @@ steps:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - bash ./run-tests.sh -c configs/models-small.txt -t 1

+- label: Encoder Decoder tests # 5min
+  source_file_dependencies:
+  - vllm/
+  - tests/encoder_decoder
+  commands:
+  - pytest -v -s encoder_decoder
+
 - label: OpenAI-Compatible Tool Use # 20 min
   fast_check: false
   mirror_hardwares: [ amd ]

tests/encoder_decoder/__init__.py

Whitespace-only changes.
tests/encoder_decoder/test_e2e_correctness.py (+98)

@@ -0,0 +1,98 @@
+"""E2E tests to verify the correctness of the encoder-decoder framework
+
+Run `pytest tests/encoder_decoder/test_e2e_correctness.py`.
+"""
+from typing import List, Optional, Tuple
+
+import pytest
+from transformers import AutoModelForSeq2SeqLM
+
+from vllm.sequence import SampleLogprobs
+from vllm.utils import is_cpu
+
+from ..conftest import DecoderPromptType
+from ..models.utils import check_logprobs_close
+
+
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs
+
+
+@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.skipif(
+    is_cpu(),
+    reason="CPU backend is not currently supported with encoder/decoder models"
+)
+def test_encoder_decoder_e2e(
+    hf_runner,
+    vllm_runner,
+    example_encoder_decoder_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    decoder_prompt_type: DecoderPromptType,
+    enforce_eager: bool,
+) -> None:
+    '''
+    End-to-End (E2E) test for the encoder-decoder framework.
+    This test evaluates the encoder-decoder functionality using the BART
+    model. We compare the outputs of the Hugging Face and vLLM
+    implementations to ensure that both implementations produce consistent
+    and correct results.
+    '''
+    test_case_prompts = example_encoder_decoder_prompts[decoder_prompt_type]
+
+    # Configuration settings for HF baseline
+    hf_kwargs = {
+        "top_k": None,
+        "num_beams": 1,
+        "repetition_penalty": 1.0,
+        "top_p": 1.0,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "no_repeat_ngram_size": None,
+        "min_length": 0
+    }
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+            test_case_prompts,
+            max_tokens,
+            num_logprobs,
+            **hf_kwargs,
+        ))
+    with vllm_runner(model, dtype=dtype,
+                     enforce_eager=enforce_eager) as vllm_model:
+        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+            test_case_prompts, max_tokens, num_logprobs)
+
+    hf_skip_tokens = (1
+                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_skip_tokens,
+    )

tests/worker/test_encoder_decoder_model_runner.py (+160, -22)

@@ -1,3 +1,4 @@
+import itertools
 from array import array
 from typing import List

@@ -7,13 +8,9 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                            SequenceData, SequenceGroupMetadata)
-from vllm.utils import is_cpu
+from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-
-# CUDA graph scenarios to test
-#
-# Currently CUDA graph is not supported
-ENFORCE_EAGER = [True]
+from vllm.worker.model_runner import _get_graph_batch_size

 BATCH_SIZES = [1, 4, 16, 64, 256]

@@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
                     reason="CPU backend is currently "
                     "unsupported for encoder/ "
                     "decoder models")
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_empty_seq_group(enforce_eager, ):
+def test_empty_seq_group():
     """Verify prepare prompt and decode returns empty output
     for empty seq group list"""

@@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
     model_input = model_runner._prepare_model_input_tensors(
@@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_prompt(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_prompt(batch_size):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce prefill-phase model inputs & attention metadata.
@@ -115,7 +107,7 @@ def test_prepare_prompt(
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )

     seq_lens: List[int] = []
@@ -281,11 +273,7 @@ def test_prepare_prompt(
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_decode(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_decode(batch_size):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -311,7 +299,7 @@ def test_prepare_decode(
         max_num_batched_tokens=100000,
         max_num_seqs=100000,
         enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
     )

     seq_lens: List[int] = []
@@ -428,7 +416,8 @@ def test_prepare_decode(
         expected,
     )

-    # Cuda graph should is currently not supported for encoder/decoer.
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
     assert attn_metadata.use_cuda_graph is False

     # Verify the lengths of input tokens & positions
@@ -484,3 +473,152 @@ def test_prepare_decode(
         dtype=actual.dtype,
     )
     assert torch.equal(actual, expected)
+
+
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
+    """
+    Tests that for encoder-decoder models with CUDA Graph capture and replay
+    enabled, the tensors used during the decode phase are correctly padded
+    for varying input batch sizes.
+    """
+    model_runner = _create_model_runner(
+        "facebook/bart-base",
+        seed=0,
+        dtype="float16",
+        max_num_batched_tokens=100000,
+        max_num_seqs=100000,
+        enable_chunked_prefill=False,
+        enforce_eager=False,
+    )
+
+    seq_lens: List[int] = []
+    encoder_seq_lens: List[int] = []
+    seq_group_metadata_list: List[SequenceGroupMetadata] = []
+    block_tables = {0: [1]}
+    cross_block_table = [2]
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        seq_len = i % (model_runner.block_size - 1) + 1
+        seq_lens.append(seq_len)
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
+        encoder_seq_lens.append(encoder_seq_len)
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=False,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables=block_tables,
+            encoder_seq_data=encoder_seq_data,
+            cross_block_table=cross_block_table,
+        )
+        assert seq_group_metadata.token_chunk_size == 1
+        seq_group_metadata_list.append(seq_group_metadata)
+
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens
+    slot_mapping = attn_metadata.slot_mapping
+    encoder_input_tokens = model_input.encoder_input_tokens
+    encoder_input_positions = model_input.encoder_input_positions
+    cross_slot_mapping = attn_metadata.cross_slot_mapping
+
+    # With CUDA Graph capture and replay enabled, the decoder and encoder
+    # input sequences will be padded. Create the expected padded tensors
+    # accordingly.
+    graph_batch_size = _get_graph_batch_size(batch_size)
+    cuda_graph_pad_size = graph_batch_size - batch_size
+    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
+    padded_encoder_seq_lens = encoder_seq_lens + list(
+        itertools.repeat(1, cuda_graph_pad_size))
+
+    assert return_seq_lens == padded_seq_lens
+    assert len(slot_mapping) == len(input_tokens)
+    assert len(cross_slot_mapping) == len(encoder_input_tokens)
+
+    # Verify attention metadata
+    device = model_runner.device
+    assert attn_metadata.num_prefills == 0
+    assert attn_metadata.num_decode_tokens > 0
+    assert torch.equal(
+        attn_metadata.seq_lens_tensor,
+        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.seq_lens == padded_seq_lens
+    assert attn_metadata.max_prefill_seq_len == 0
+    assert attn_metadata.max_decode_seq_len == max(seq_lens)
+    # - Encoder attention metadata
+    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
+    assert torch.equal(
+        attn_metadata.encoder_seq_lens_tensor,
+        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
+    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)
+
+    # Verify block tables are correct for prompts
+    # - Decoder self-attention. Pad the block tables as expected.
+    expected = [block_tables[0] for _ in range(batch_size)]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.block_tables,
+        expected,
+    )
+    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
+    #   as expected.
+    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.cross_block_tables,
+        expected,
+    )
+
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
+    assert attn_metadata.use_cuda_graph is True
+
+    # Verify the lengths of input tokens & positions
+    # - Decoder
+    assert len(input_tokens) == len(padded_seq_lens)
+    assert len(input_positions) == len(padded_seq_lens)
+    # -- An indirect check that model_input.input_tokens
+    #    and model_input.input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        input_tokens,
+        input_positions,
+    )
+    # - Encoder. During decode the encoder inputs are empty (the encoder only
+    #   runs during prefill).
+    assert len(encoder_input_tokens) == 0
+    assert len(encoder_input_positions) == 0
+    # -- An indirect check that model_input.encoder_input_tokens
+    #    and model_input.encoder_input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        encoder_input_tokens,
+        encoder_input_positions,
+    )
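
The new test hinges on the runner rounding every decode batch up to the nearest captured graph batch size and padding the decoder and encoder sequence-length lists with dummy length-1 entries. A minimal, self-contained sketch of that rounding-and-padding rule (the alignment of 1, 2, 4, then multiples of 8 is assumed to mirror _get_graph_batch_size in vllm/worker/model_runner.py around this commit; the helper names here are illustrative):

import itertools
from typing import List, Tuple

_BATCH_SIZE_ALIGNMENT = 8  # assumed capture granularity beyond batch size 4


def graph_batch_size(batch_size: int) -> int:
    """Round a batch size up to the nearest captured CUDA graph size."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def pad_seq_lens(seq_lens: List[int]) -> Tuple[int, List[int]]:
    """Pad a list of decode-phase seq lens with dummy length-1 entries."""
    pad = graph_batch_size(len(seq_lens)) - len(seq_lens)
    return pad, seq_lens + list(itertools.repeat(1, pad))


# A decode batch of 6 sequences is padded up to the captured size of 8.
assert pad_seq_lens([5, 3, 7, 2, 9, 4]) == (2, [5, 3, 7, 2, 9, 4, 1, 1])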

vllm/attention/backends/abstract.py (+13, -4)

@@ -156,18 +156,27 @@ def graph_clone(self, batch_size: int) -> "AttentionState[T]":
         ...

     @abstractmethod
-    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
+    def graph_capture_get_metadata_for_batch(
+            self,
+            batch_size: int,
+            is_encoder_decoder_model: bool = False) -> T:
         """Get attention metadata for CUDA graph capture of batch_size."""
         ...

     @abstractmethod
-    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
+    def get_graph_input_buffers(
+            self,
+            attn_metadata: T,
+            is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
         """Get attention-specific input buffers for CUDA graph capture."""
         ...

     @abstractmethod
-    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
-                                    attn_metadata: T) -> None:
+    def prepare_graph_input_buffers(
+            self,
+            input_buffers: Dict[str, Any],
+            attn_metadata: T,
+            is_encoder_decoder_model: bool = False) -> None:
         """In-place modify input buffers dict for CUDA graph replay."""
         ...

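
Concrete backends implement these hooks; the new is_encoder_decoder_model flag lets the same capture-time builders also size encoder and cross-attention fields. A self-contained, hypothetical sketch of that pattern (toy class and field names, not vLLM's actual backend code; the fields echo the test assertions above):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyAttnMetadata:
    seq_lens: List[int]
    use_cuda_graph: bool = True
    encoder_seq_lens: Optional[List[int]] = None
    cross_block_tables: Optional[List[List[int]]] = None


class ToyAttentionState:

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> ToyAttnMetadata:
        # Decoder self-attention fields are always sized for the captured
        # batch; dummy length-1 sequences stand in for the padded slots.
        metadata = ToyAttnMetadata(seq_lens=[1] * batch_size)
        if is_encoder_decoder_model:
            # Encoder / cross-attention fields are captured too, so graph
            # replay can reuse the same buffers during decode.
            metadata.encoder_seq_lens = [1] * batch_size
            metadata.cross_block_tables = [[] for _ in range(batch_size)]
        return metadata


print(ToyAttentionState().graph_capture_get_metadata_for_batch(
    4, is_encoder_decoder_model=True))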
