
Commit f65b904

fix tests in v1/worker
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent e412e02 commit f65b904

2 files changed: +74, -22 lines changed


tests/v1/worker/test_gpu_input_batch.py

Lines changed: 33 additions & 6 deletions
@@ -9,9 +9,11 @@
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
-                                            InputBatch)
+from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
+from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
@@ -22,6 +24,27 @@
 MAX_NUM_PROMPT_TOKENS = 64
 
 
+def get_kv_cache_config() -> KVCacheConfig:
+    return KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=1024),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.0"],
+                kv_cache_spec=FullAttentionSpec(
+                    block_size=1,
+                    num_kv_heads=1,
+                    head_size=16,
+                    dtype=torch.float16,
+                    use_mla=False,
+                ),
+            ),
+        ],
+    )
+
+
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
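
The helper above builds the smallest config the new InputBatch signature accepts: one KV cache tensor and one full-attention group. For orientation, a hedged sketch of what a two-group variant would look like with the same constructors and imports used in this test file; the second layer and group here are purely illustrative and not part of this commit. Each extra group is what adds another per-group table to MultiGroupBlockTable and another inner list to block_ids in the hunks below.

def get_two_group_kv_cache_config() -> KVCacheConfig:
    # Hypothetical variant for illustration only: two layers, each in its
    # own KV cache group, sharing the same FullAttentionSpec shape.
    spec = FullAttentionSpec(block_size=1,
                             num_kv_heads=1,
                             head_size=16,
                             dtype=torch.float16,
                             use_mla=False)
    return KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
            "layer.1": KVCacheTensor(size=1024),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=spec),
            KVCacheGroupSpec(layer_names=["layer.1"], kv_cache_spec=spec),
        ],
    )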
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
         elif isinstance(a, np.ndarray):
             if np.allclose(a, b):
                 is_same = True
+        elif isinstance(a, MultiGroupBlockTable):
+            for a_i, b_i in zip(a.block_tables, b.block_tables):
+                _compare_objs(a_i, b_i)
+            is_same = True
         elif isinstance(a, (BlockTable, SamplingMetadata)):
             _compare_objs(a, b)
             is_same = True  # if we make it here must be same
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[],
+        block_ids=[[]],
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
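
block_ids changes here from a flat list of block IDs to one list per KV cache group. A minimal sketch of the new shape, assuming the single-group config from get_kv_cache_config(); the block IDs are illustrative, not taken from this commit.

# block_ids is now indexed by KV cache group first, then by position in
# the request's block table for that group.
block_ids: list[list[int]] = [[1, 2, 3]]   # group 0 owns blocks 1, 2, 3
assert len(block_ids) == 1                 # one KV cache group configured
assert block_ids[0][0] == 1                # first block of group 0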
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
-        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
+        kv_cache_config=get_kv_cache_config(),
     )
 
     reqs: list[CachedRequestState] = []
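
Tying this file's changes together, a hedged sketch (batch size and CPU device chosen arbitrarily for illustration) of how a test built on the new signature reaches the per-group block table; the [0] indexing mirrors what test_gpu_model_runner.py does in the second diff below.

import torch
from vllm.v1.worker.gpu_input_batch import InputBatch

input_batch = InputBatch(
    max_num_reqs=2,
    max_model_len=1024,
    max_num_batched_tokens=1024,
    device=torch.device("cpu"),
    pin_memory=False,
    vocab_size=1024,
    kv_cache_config=get_kv_cache_config(),
)
# block_table is now a MultiGroupBlockTable; index it to get the plain
# BlockTable that backs KV cache group 0.
group0_table = input_batch.block_table[0]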

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 41 additions & 16 deletions
@@ -1,29 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
-import weakref
 
 import pytest
-import torch
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
 def initialize_kv_cache(runner: GPUModelRunner):
     """
     Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
     """
-    kv_cache_spec = FullAttentionSpec(block_size=16,
-                                      num_kv_heads=1,
-                                      head_size=64,
-                                      dtype=torch.float16,
-                                      use_mla=False)
-    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
-        weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
+    kv_cache_config = KVCacheConfig(
+        num_blocks=10,
+        tensors={
+            "layer.0": KVCacheTensor(size=1024),
+        },
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                layer_names=["layer.0"],
+                kv_cache_spec=FullAttentionSpec(
+                    block_size=16,
+                    num_kv_heads=runner.model_config.get_num_kv_heads(
+                        runner.parallel_config),
+                    head_size=runner.model_config.get_head_size(),
+                    dtype=runner.kv_cache_dtype,
+                    use_mla=False,
+                ))
+        ])
+    runner.kv_cache_config = kv_cache_config
+    runner.input_batch = InputBatch(
+        max_num_reqs=runner.max_num_reqs,
+        max_model_len=runner.max_model_len,
+        max_num_batched_tokens=runner.max_num_tokens,
+        device=runner.device,
+        pin_memory=runner.pin_memory,
+        vocab_size=runner.model_config.get_vocab_size(),
+        kv_cache_config=kv_cache_config,
+    )
+    runner.initialize_attn_backend(kv_cache_config)
 
 
 @pytest.fixture
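
The rewritten helper now mirrors three steps of GPUModelRunner.initialize_kv_cache(): record the KVCacheConfig on the runner, rebuild the InputBatch with it, and initialize the attention backends. A hedged sketch of how a test could exercise it through the model_runner fixture defined below; the assertions and the call site are illustrative assumptions, not part of this commit.

def test_initialize_kv_cache_smoke(model_runner):
    # Assumes (for this sketch) that the fixture has not already
    # initialized the KV cache; otherwise this call would be redundant.
    initialize_kv_cache(model_runner)
    assert model_runner.kv_cache_config.num_blocks == 10
    # One KV cache group was configured, so there is one per-group table.
    assert len(model_runner.input_batch.block_table.block_tables) == 1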
@@ -48,10 +70,12 @@ def model_runner():
         swap_space=0,
         cache_dtype="auto",
     )
+    parallel_config = ParallelConfig()
     vllm_config = VllmConfig(
         model_config=model_config,
         cache_config=cache_config,
         scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
     )
 
     device = "cuda"
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
                 mm_hashes=[],
                 mm_positions=[],
                 sampling_params=SamplingParams(),
-                block_ids=[0],
+                block_ids=[[0]],
                 num_computed_tokens=0,
                 lora_request=None,
             ))
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
 
 def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
     req_index = model_runner.input_batch.req_id_to_index[req_id]
-    block_table = model_runner.input_batch.block_table
+    block_table = model_runner.input_batch.block_table[0]
     req_state = model_runner.requests[req_id]
-    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+    if block_table.num_blocks_per_row[req_index] != len(
+            req_state.block_ids[0]):
         return False
     num_blocks = block_table.num_blocks_per_row[req_index]
     return (block_table.block_table_np[req_index, :num_blocks] ==
-            req_state.block_ids).all()
+            req_state.block_ids[0]).all()
 
 
 def test_update_states_new_request(model_runner):
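
The helper above checks only group 0, which is all the single-group test config needs. A hedged sketch (not part of this commit) of how the same check would generalize to every KV cache group, using only attributes that appear in this diff:

def _all_group_block_tables_match(model_runner, req_id: str) -> bool:
    req_index = model_runner.input_batch.req_id_to_index[req_id]
    req_state = model_runner.requests[req_id]
    block_tables = model_runner.input_batch.block_table.block_tables
    for group_id, block_table in enumerate(block_tables):
        num_blocks = block_table.num_blocks_per_row[req_index]
        if num_blocks != len(req_state.block_ids[group_id]):
            return False
        if not (block_table.block_table_np[req_index, :num_blocks] ==
                req_state.block_ids[group_id]).all():
            return False
    return True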
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
         req_id=req_id,
         resumed_from_preemption=False,
         new_token_ids=[],
-        new_block_ids=[],
+        new_block_ids=[[]],
         num_computed_tokens=0,
     )
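
For contrast, a hedged sketch of the same CachedRequestData when a request really is resumed from preemption and receives fresh blocks; the request ID and block IDs are illustrative, and the nested list again carries one inner list per KV cache group.

req_id = "req_0"  # hypothetical request ID for illustration
resumed_req_data = CachedRequestData(
    req_id=req_id,
    resumed_from_preemption=True,
    new_token_ids=[],
    new_block_ids=[[0, 1]],   # group 0 re-allocated blocks 0 and 1
    num_computed_tokens=0,
)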
