@@ -1,29 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
-import weakref
 
 import pytest
-import torch
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
-    kv_cache_spec = FullAttentionSpec(block_size=16,
-                                      num_kv_heads=1,
-                                      head_size=64,
-                                      dtype=torch.float16,
-                                      use_mla=False)
-    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
-        weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=1024),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.0"],
+                kv_cache_spec=FullAttentionSpec(
+                    block_size=16,
+                    num_kv_heads=runner.model_config.get_num_kv_heads(
+                        runner.parallel_config),
+                    head_size=runner.model_config.get_head_size(),
+                    dtype=runner.kv_cache_dtype,
+                    use_mla=False,
+                ))
+        ])
+    runner.kv_cache_config = kv_cache_config
+    runner.input_batch = InputBatch(
+        max_num_reqs=runner.max_num_reqs,
+        max_model_len=runner.max_model_len,
+        max_num_batched_tokens=runner.max_num_tokens,
+        device=runner.device,
+        pin_memory=runner.pin_memory,
+        vocab_size=runner.model_config.get_vocab_size(),
+        kv_cache_config=kv_cache_config,
+    )
+    runner.initialize_attn_backend(kv_cache_config)
 
 
 @pytest.fixture
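Note on the hunk above: initialize_kv_cache() no longer builds a bare FullAttentionSpec; it now assembles a KVCacheConfig whose kv_cache_groups each map a set of layer names onto one shared spec, then rebuilds the InputBatch around that config. A minimal pure-Python sketch of the grouping shape (the class names mirror the diff, but the definitions here are stand-ins, not vLLM's real classes):

from dataclasses import dataclass

@dataclass
class Spec:                     # stand-in for FullAttentionSpec
    block_size: int

@dataclass
class Group:                    # stand-in for KVCacheGroupSpec
    layer_names: list[str]
    kv_cache_spec: Spec

@dataclass
class Config:                   # stand-in for KVCacheConfig
    num_blocks: int
    kv_cache_groups: list[Group]

# One full-attention group covering every layer, as in the test above.
cfg = Config(num_blocks=10,
             kv_cache_groups=[Group(["layer.0"], Spec(block_size=16))])
assert len(cfg.kv_cache_groups) == 1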
@@ -48,10 +70,12 @@ def model_runner():
         swap_space=0,
         cache_dtype="auto",
     )
+    parallel_config = ParallelConfig()
     vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
     )
 
     device = "cuda"
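The fixture gains an explicit ParallelConfig because the new initialize_kv_cache() derives the head count via model_config.get_num_kv_heads(parallel_config) instead of hardcoding 1. A hedged sketch of the per-rank arithmetic this presumably performs under tensor parallelism (my function, not vLLM's API):

def num_kv_heads_per_rank(total_kv_heads: int, tensor_parallel_size: int) -> int:
    # Assumed behavior: KV heads are sharded across TP ranks, never below 1.
    return max(1, total_kv_heads // tensor_parallel_size)

assert num_kv_heads_per_rank(8, 2) == 4
assert num_kv_heads_per_rank(1, 4) == 1  # a single KV head is replicated

With the default ParallelConfig (tensor_parallel_size=1), the test sees the model's full KV head count.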
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[0],
+                block_ids=[[0]],
                 num_computed_tokens=0,
                 lora_request=None,
             ))
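This small-looking change is the schema shift at the heart of the diff: with grouped KV caches, a request's block_ids becomes a list of lists, one inner list per KV cache group. A hypothetical illustration (plain Python; the two-group case is mine, not from the diff):

# One KV cache group: the old flat [0] becomes [[0]].
block_ids_one_group = [[0]]
# A hypothetical two-group model would carry one inner list per group.
block_ids_two_groups = [[0, 3], [5, 9]]
for group_idx, group_blocks in enumerate(block_ids_two_groups):
    print(f"group {group_idx} maps to blocks {group_blocks}")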
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
 
 def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     req_index = model_runner.input_batch.req_id_to_index[req_id]
-    block_table = model_runner.input_batch.block_table
+    block_table = model_runner.input_batch.block_table[0]
     req_state = model_runner.requests[req_id]
-    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+    if block_table.num_blocks_per_row[req_index] != len(
+            req_state.block_ids[0]):
         return False
     num_blocks = block_table.num_blocks_per_row[req_index]
     return (block_table.block_table_np[req_index, :num_blocks] ==
-            req_state.block_ids).all()
+            req_state.block_ids[0]).all()
 
 
 def test_update_states_new_request(model_runner):
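Correspondingly, input_batch.block_table is now indexed by group, and the helper compares group 0's row against group 0's block IDs. A self-contained NumPy sketch of the same comparison; FakeBlockTable mocks only the two attributes the helper reads (num_blocks_per_row, block_table_np), everything else is assumption:

import numpy as np

class FakeBlockTable:
    def __init__(self, max_reqs: int, max_blocks: int):
        self.num_blocks_per_row = np.zeros(max_reqs, dtype=np.int32)
        self.block_table_np = np.zeros((max_reqs, max_blocks), dtype=np.int32)

def block_table_matches(table: FakeBlockTable, req_index: int,
                        group_block_ids: list) -> bool:
    # Same shape of check as _is_req_state_block_table_match, for one group.
    num_blocks = table.num_blocks_per_row[req_index]
    if num_blocks != len(group_block_ids):
        return False
    return bool((table.block_table_np[req_index, :num_blocks] ==
                 group_block_ids).all())

table = FakeBlockTable(max_reqs=4, max_blocks=8)
table.num_blocks_per_row[0] = 1      # request 0 owns one block
table.block_table_np[0, 0] = 0       # ...and it is block 0
assert block_table_matches(table, 0, [0])   # i.e. block_ids[0] == [0]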
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[],
+        new_block_ids=[[]],
         num_computed_tokens=0,
     )
 
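The same nesting applies to CachedRequestData: a resumed request that gained no new blocks still sends one (empty) inner list per KV cache group, so the group structure is always present. Illustratively:

new_block_ids = [[]]    # one group, zero new blocks for it
assert len(new_block_ids) == 1 and new_block_ids[0] == []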