
Commit ee3d1ed

heheda12345 authored and mawong-amd committed
[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (vllm-project#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent 0e48df2 commit ee3d1ed
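
In short: each v1 attention metadata builder is now constructed with the KVCacheSpec of the layers it serves and the InputBatch's BlockTable, instead of reading block_size, slot_mapping, and KV-cache geometry off the GPUModelRunner. A minimal sketch of the new construction path, mirroring the updated test helper below (the runner variable is assumed to be an already-initialized GPUModelRunner):

import weakref

import torch

from vllm.v1.kv_cache_interface import FullAttentionSpec

# Describe the KV-cache layout handled by this builder.
kv_cache_spec = FullAttentionSpec(block_size=16,
                                  num_kv_heads=1,
                                  head_size=64,
                                  dtype=torch.float16,
                                  use_mla=False)

# Builders now take (runner, kv_cache_spec, block_table).
attn_metadata_builder = runner.attn_backend.get_builder_cls()(
    weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)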

File tree

11 files changed (+132, -68 lines)


tests/v1/worker/test_gpu_input_batch.py

Lines changed: 3 additions & 0 deletions
@@ -221,6 +221,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
@@ -310,6 +311,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
@@ -318,6 +320,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         max_num_reqs=batch_size,
         max_model_len=1024,
         max_num_blocks_per_req=10,
+        max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 20 additions & 1 deletion
@@ -1,14 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
+import weakref
+
 import pytest
+import torch
 
 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
+from vllm.v1.kv_cache_interface import FullAttentionSpec
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
+def initialize_kv_cache(runner: GPUModelRunner):
+    """
+    Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
+    """
+    kv_cache_spec = FullAttentionSpec(block_size=16,
+                                      num_kv_heads=1,
+                                      head_size=64,
+                                      dtype=torch.float16,
+                                      use_mla=False)
+    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
+        weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
+
+
 @pytest.fixture
 def model_runner():
     scheduler_config = SchedulerConfig(
@@ -38,7 +55,9 @@ def model_runner():
     )
 
     device = "cuda"
-    return GPUModelRunner(vllm_config, device)
+    runner = GPUModelRunner(vllm_config, device)
+    initialize_kv_cache(runner)
+    return runner
 
 
 def _schedule_new_request(*req_ids: str) -> SchedulerOutput:

vllm/v1/attention/backends/flash_attn.py

Lines changed: 30 additions & 17 deletions
@@ -19,6 +19,8 @@
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -167,7 +169,7 @@ def make_local_attention_virtual_batches(
     query_start_loc_np: np.ndarray,
     seq_lens_np: np.ndarray,
     block_table: torch.Tensor,
-    page_size: int = 0,
+    block_size: int = 0,
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
     q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
     actual_batch_size = seq_lens_np.shape[0]
@@ -238,14 +240,14 @@ def make_local_attention_virtual_batches(
     # For the example the local attention blocks start at:
     #  _b0_  _____b1_____  _b2_
     # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
-    block_starts = k_seqstarts_absolute // page_size
-    assert attn_chunk_size % page_size == 0, \
+    block_starts = k_seqstarts_absolute // block_size
+    assert attn_chunk_size % block_size == 0, \
         f"attn_chunk_size {attn_chunk_size} is not " \
-        f"divisible by page_size {page_size}"
-    pages_per_local_batch = attn_chunk_size // page_size
+        f"divisible by block_size {block_size}"
+    pages_per_local_batch = attn_chunk_size // block_size
 
     # Create a block_table for the local attention blocks
-    # For out example if we have a block-table like (assuming page_size=2):
+    # For out example if we have a block-table like (assuming block_size=2):
     #   block_table = [
     #     [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9], < batch 0
     #     [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
@@ -289,7 +291,8 @@ def _get_sliding_window_configs(
 
 class FlashAttentionMetadataBuilder:
 
-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
         model_config = runner.model_config
         compilation_config = runner.vllm_config.compilation_config
 
@@ -299,7 +302,9 @@ def __init__(self, runner: "GPUModelRunner"):
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
-        self.page_size = self.runner.block_size
+        self.block_size = kv_cache_spec.block_size
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
 
         if get_flash_attn_version() == 3:
             self.aot_schedule = not compilation_config.full_cuda_graph
@@ -323,9 +328,17 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+
+        block_table.slot_mapping[:num_actual_tokens].copy_(
+            block_table.slot_mapping_cpu[:num_actual_tokens],
+            non_blocking=True)
+        # Fill unused with -1. Needed for reshape_and_cache in full cuda graph
+        # mode.
+        block_table.slot_mapping[num_actual_tokens:].fill_(-1)
+
+        slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
         if self.aot_sliding_window is None:
             self.aot_sliding_window = (-1, -1)
@@ -354,7 +367,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.page_size,
+                    page_size=self.block_size,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
                     window_size=self.aot_sliding_window,
@@ -365,12 +378,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
         local_attn_metadata = None
         if self.runner.attention_chunk_size is not None:
             seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table = make_local_attention_virtual_batches(
+                virt_block_table_tensor = make_local_attention_virtual_batches(
                     self.runner.attention_chunk_size,
                     self.runner.query_start_loc_np[:num_reqs + 1],
                     self.runner.seq_lens_np[:num_reqs],
-                    block_table,
-                    self.runner.block_size,
+                    block_table_tensor,
+                    self.block_size,
                 )
             local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
                 self.runner.device, non_blocking=True)
@@ -389,7 +402,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
                 local_query_start_loc=local_query_start_loc,
                 local_seqused_k=local_seqused_k,
-                local_block_table=virt_block_table,
+                local_block_table=virt_block_table_tensor,
                 local_max_query_len=local_max_query_len,
                 local_max_seq_len=local_max_seq_len,
                 local_scheduler_metadata=local_scheduler_metadata,
@@ -440,7 +453,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            block_table=block_table,
+            block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
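
The slot-mapping buffers now live on the BlockTable rather than the runner: the builder copies the CPU-side slot mapping to the device asynchronously and fills the unused tail with -1 so reshape_and_cache stays safe when a full CUDA graph replays with a smaller batch. A self-contained torch sketch of that pattern, with illustrative buffer names and sizes rather than the actual BlockTable attributes (assumes a CUDA device):

import torch

# Stand-ins for BlockTable.slot_mapping_cpu / BlockTable.slot_mapping.
max_num_batched_tokens = 8192
slot_mapping_cpu = torch.zeros(max_num_batched_tokens, dtype=torch.int64,
                               pin_memory=True)
slot_mapping_gpu = torch.empty(max_num_batched_tokens, dtype=torch.int64,
                               device="cuda")

num_actual_tokens = 1024
# Async H2D copy of only the valid prefix.
slot_mapping_gpu[:num_actual_tokens].copy_(
    slot_mapping_cpu[:num_actual_tokens], non_blocking=True)
# Mark the unused tail so a captured CUDA graph never writes stale slots.
slot_mapping_gpu[num_actual_tokens:].fill_(-1)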

vllm/v1/attention/backends/flashinfer.py

Lines changed: 20 additions & 15 deletions
@@ -19,6 +19,8 @@
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -202,7 +204,8 @@ def __post_init__(self):
 
 class FlashInferMetadataBuilder:
 
-    def __init__(self, runner: GPUModelRunner):
+    def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
         self.runner = runner
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -213,6 +216,8 @@ def __init__(self, runner: GPUModelRunner):
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
         self.vllm_config = get_current_vllm_config()
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
@@ -400,13 +405,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
-        page_size = self.runner.block_size
+        page_size = self.kv_cache_spec.block_size
         device = self.runner.device
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
         block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -422,24 +426,25 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                  dtype=torch.int32,
                                                  device=device)
-            shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
+            shared_kv_page_indices = block_table_tensor[
+                0, :num_common_kv_blocks]
             shared_kv_last_page_len = torch.tensor([page_size],
                                                    dtype=torch.int32,
                                                    device=device)
             # Remove the blocks of the shared prefix from all requests.
-            block_table = block_table[:, num_common_kv_blocks:]
+            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
             block_table_bounds -= num_common_kv_blocks
         else:
             shared_qo_indptr = None
             shared_kv_page_indptr = None
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
-                             device=block_table.device).unsqueeze(0)
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
+                             device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1,
@@ -459,10 +464,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
             num_qo_heads=self.runner.num_query_heads,
-            num_kv_heads=self.runner.num_kv_heads,
-            head_dim=self.runner.head_size,
+            num_kv_heads=self.kv_cache_spec.num_kv_heads,
+            head_dim=self.kv_cache_spec.head_size,
             page_size=page_size,
-            data_type=self.runner.kv_cache_dtype,
+            data_type=self.kv_cache_spec.dtype,
             q_data_type=self.runner.dtype,
             slot_mapping=slot_mapping,
             num_decodes=self._num_decodes,
@@ -481,7 +486,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         return attn_metadata
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        if self.runner.kv_cache_dtype != self.runner.model_config.dtype:
+        if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
             # kv cache dtype to something different from query dtype.
             return False

vllm/v1/attention/backends/mla/common.py

Lines changed: 15 additions & 8 deletions
@@ -207,6 +207,8 @@
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -334,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]):
 
     def __init__(self,
                  runner: "GPUModelRunner",
+                 kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
@@ -346,10 +350,11 @@ def __init__(self,
             runner.parallel_config)
         self.mla_dims = get_mla_dims(model_config)
         self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
+        self.kv_cache_spec = kv_cache_spec
 
         # Dont try to access the runner on AMD
         if self.aot_schedule:
-            self.page_size = self.runner.block_size
+            self.page_size = self.kv_cache_spec.block_size
 
         if self.chunked_prefill_enabled:
             self.chunked_prefill_workspace_size = min(
@@ -375,6 +380,7 @@ def __init__(self,
                 dtype=model_config.dtype,
                 device=runner.device,
             )
+        self.block_table = block_table
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -436,9 +442,10 @@ def reorder_batch(self, input_batch: "InputBatch",
 
         return modified_batch
 
-    def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
-            block_table=block_table,
+            block_table=block_table_tensor,
            seq_lens=seq_lens,
         )
 
@@ -451,9 +458,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
         device = self.runner.device
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table = self.block_table
+        block_table_tensor = block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
         query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +537,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
             self.chunked_prefill_workspace_size
 
         prefill_metadata = MLACommonPrefillMetadata(
-            block_table=block_table[reqs_start:, ...],
+            block_table=block_table_tensor[reqs_start:, ...],
             query_start_loc=prefill_query_start_loc,
             max_query_len=max_query_len,
             chunked_context=chunked_context_metadata,
@@ -539,7 +546,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         decode_metadata = None
         if self._num_decodes > 0:
             decode_metadata = self._build_decode(
-                block_table=block_table[:self._num_decodes, ...],
+                block_table_tensor=block_table_tensor[:self._num_decodes, ...],
                 seq_lens=seq_lens[:self._num_decodes],
             )
 
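
Since MLACommonMetadataBuilder now takes the spec and block table before the optional metadata_cls, backend-specific MLA builders have to forward the extra arguments, and any override of _build_decode has to use the renamed block_table_tensor parameter (the common code passes it by keyword). A hedged sketch of what such a subclass looks like after this change; the class name below is a placeholder for illustration, not one of the real MLA backends:

from vllm.v1.attention.backends.mla.common import (MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable


class PlaceholderMLAMetadataBuilder(
        MLACommonMetadataBuilder[MLACommonMetadata]):
    """Illustrative subclass forwarding the new constructor arguments."""

    def __init__(self, runner, kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table,
                         metadata_cls=MLACommonMetadata)

    def _build_decode(self, block_table_tensor, seq_lens):
        # Keep the renamed parameter: the base class calls
        # _build_decode(block_table_tensor=..., seq_lens=...).
        return super()._build_decode(block_table_tensor, seq_lens)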
