+import itertools
 from array import array
 from typing import List

 from vllm.engine.arg_utils import EngineArgs
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                            SequenceData, SequenceGroupMetadata)
-from vllm.utils import is_cpu
+from vllm.utils import is_cpu, make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-
-# CUDA graph scenarios to test
-#
-# Currently CUDA graph is not supported
-ENFORCE_EAGER = [True]
+from vllm.worker.model_runner import _get_graph_batch_size

 BATCH_SIZES = [1, 4, 16, 64, 256]
@@ -40,8 +37,7 @@ def _create_model_runner(model: str, *args,
                    reason="CPU backend is currently "
                    "unsupported for encoder/ "
                    "decoder models")
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_empty_seq_group(enforce_eager, ):
+def test_empty_seq_group():
    """Verify prepare prompt and decode returns empty output
    for empty seq group list"""

@@ -52,7 +48,7 @@ def test_empty_seq_group(enforce_eager, ):
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input_tensors(
@@ -85,11 +81,7 @@ def test_empty_seq_group(enforce_eager, ):
                    "unsupported for encoder/ "
                    "decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_prompt(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_prompt(batch_size):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce prefill-phase model inputs & attention metadata.
@@ -115,7 +107,7 @@ def test_prepare_prompt(
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
    )

    seq_lens: List[int] = []
@@ -281,11 +273,7 @@ def test_prepare_prompt(
                    "unsupported for encoder/ "
                    "decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
-@pytest.mark.parametrize("enforce_eager", ENFORCE_EAGER)
-def test_prepare_decode(
-    batch_size,
-    enforce_eager,
-):
+def test_prepare_decode(batch_size):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.
@@ -311,7 +299,7 @@ def test_prepare_decode(
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
-        enforce_eager=enforce_eager,
+        enforce_eager=True,
    )

    seq_lens: List[int] = []
@@ -428,7 +416,8 @@ def test_prepare_decode(
        expected,
    )

-    # Cuda graph should is currently not supported for encoder/decoer.
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
@@ -484,3 +473,152 @@ def test_prepare_decode(
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)
+
+
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
+    """
+    Tests that for encoder-decoder models with CUDA Graph capture and replay
+    enabled, the tensors used during the decode phase are correctly padded
+    for varying input batch sizes.
+    """
+    model_runner = _create_model_runner(
+        "facebook/bart-base",
+        seed=0,
+        dtype="float16",
+        max_num_batched_tokens=100000,
+        max_num_seqs=100000,
+        enable_chunked_prefill=False,
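+        # Keep CUDA Graph capture/replay enabled here (unlike the eager-only
+        # tests above), so decode inputs should be padded to the graph batch
+        # size.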
+        enforce_eager=False,
+    )
+
+    seq_lens: List[int] = []
+    encoder_seq_lens: List[int] = []
+    seq_group_metadata_list: List[SequenceGroupMetadata] = []
+    block_tables = {0: [1]}
+    cross_block_table = [2]
+    for i in range(batch_size):
+        # make sure all tokens fit into one block
+        seq_len = i % (model_runner.block_size - 1) + 1
+        seq_lens.append(seq_len)
+        seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
+        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
+        encoder_seq_lens.append(encoder_seq_len)
+        encoder_seq_data = SequenceData(
+            array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+        seq_group_metadata = SequenceGroupMetadata(
+            request_id=f"test_{i}",
+            is_prompt=False,
+            seq_data={0: seq_data},
+            sampling_params=SamplingParams(temperature=0),
+            block_tables=block_tables,
+            encoder_seq_data=encoder_seq_data,
+            cross_block_table=cross_block_table,
+        )
+        assert seq_group_metadata.token_chunk_size == 1
+        seq_group_metadata_list.append(seq_group_metadata)
+
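+    # Build decode-phase model inputs; with CUDA Graphs enabled, the batch is
+    # expected to be padded out to the captured graph batch size.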
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens
+    slot_mapping = attn_metadata.slot_mapping
+    encoder_input_tokens = model_input.encoder_input_tokens
+    encoder_input_positions = model_input.encoder_input_positions
+    cross_slot_mapping = attn_metadata.cross_slot_mapping
+
+    # With CUDA Graph capture and replay enabled, the decoder and encoder
+    # input sequences will be padded. Create the expected padded tensors
+    # accordingly.
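+    # _get_graph_batch_size() returns the (padded) batch size that a CUDA
+    # graph is captured for; the padding slots are expected to show up as
+    # dummy sequences of length 1.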
+    graph_batch_size = _get_graph_batch_size(batch_size)
+    cuda_graph_pad_size = graph_batch_size - batch_size
+    padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
+    padded_encoder_seq_lens = encoder_seq_lens + list(
+        itertools.repeat(1, cuda_graph_pad_size))
+
+    assert return_seq_lens == padded_seq_lens
+    assert len(slot_mapping) == len(input_tokens)
+    assert len(cross_slot_mapping) == len(encoder_input_tokens)
+
+    # Verify attention metadata
+    device = model_runner.device
+    assert attn_metadata.num_prefills == 0
+    assert attn_metadata.num_decode_tokens > 0
+    assert torch.equal(
+        attn_metadata.seq_lens_tensor,
+        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.seq_lens == padded_seq_lens
+    assert attn_metadata.max_prefill_seq_len == 0
+    assert attn_metadata.max_decode_seq_len == max(seq_lens)
+    # - Encoder attention metadata
+    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
+    assert torch.equal(
+        attn_metadata.encoder_seq_lens_tensor,
+        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
+    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
+    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)
+
+    # Verify block tables are correct for prompts
+    # - Decoder self-attention. Pad the block tables as expected.
+    expected = [block_tables[0] for _ in range(batch_size)]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.block_tables,
+        expected,
+    )
+    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
+    # as expected.
+    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    expected = make_tensor_with_pad(
+        expected,
+        max_len=64,
+        pad=0,
+        dtype=torch.int32,
+        device=model_runner.device,
+    )
+    assert torch.equal(
+        attn_metadata.cross_block_tables,
+        expected,
+    )
+
+    # Model runner's CUDAGraph setting should be propagated to attention
+    # metadata.
+    assert attn_metadata.use_cuda_graph is True
+
+    # Verify the lengths of input tokens & positions
+    # - Decoder
+    assert len(input_tokens) == len(padded_seq_lens)
+    assert len(input_positions) == len(padded_seq_lens)
+    # -- An indirect check that model_input.input_tokens
+    #    and model_input.input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        input_tokens,
+        input_positions,
+    )
+    # - Encoder
+    assert len(encoder_input_tokens) == 0
+    assert len(encoder_input_positions) == 0
+    # -- An indirect check that model_input.encoder_input_tokens
+    #    and model_input.encoder_input_positions are correct -
+    #    by design of the test, the input tokens are
+    #    equal to the input position values, so if
+    #    the model_input data structure has the correct
+    #    values then these two should be equal
+    assert torch.equal(
+        encoder_input_tokens,
+        encoder_input_positions,
+    )