
[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders #17483


Merged: 31 commits, May 10, 2025
Commits
d35146f
remove num_input_tokens from attn_metadata
heheda12345 Apr 25, 2025
20d930b
fix
heheda12345 Apr 25, 2025
d17daf5
Merge branch 'main' of github.com:vllm-project/vllm into num_input_to…
heheda12345 Apr 29, 2025
f0636df
per_layer_attn_metadata
heheda12345 Apr 29, 2025
1e2f970
Merge branch 'num_input_tokens' of github.com:heheda12345/vllm into p…
heheda12345 Apr 29, 2025
dd08b5b
updaet comment
heheda12345 Apr 29, 2025
ab4389e
update tpu code
heheda12345 Apr 29, 2025
20a1d22
fix kv connector
heheda12345 Apr 29, 2025
e7ffa63
Merge branch 'main' of github.com:vllm-project/vllm into per_layer_at…
heheda12345 Apr 29, 2025
5816b17
Merge branch 'main' of github.com:vllm-project/vllm into per_layer_at…
heheda12345 Apr 30, 2025
4679b4c
fix eagle
heheda12345 Apr 30, 2025
e484ba9
update test
heheda12345 Apr 30, 2025
29bc590
move slot_mapping to block_table
heheda12345 Apr 30, 2025
2befc41
Merge branch 'per_layer_attn_metadata' of github.com:heheda12345/vllm…
heheda12345 Apr 30, 2025
5fea105
update attn backends
heheda12345 Apr 30, 2025
ff2c72d
Merge branch 'main' of github.com:vllm-project/vllm into slot_mapping
heheda12345 May 1, 2025
178ed85
revert scheduler changes
heheda12345 May 1, 2025
84a95f8
revert unrelated changes
heheda12345 May 1, 2025
b5b7ae4
Merge branch 'main' of github.com:vllm-project/vllm into slot_mapping
heheda12345 May 6, 2025
084599c
clean up
heheda12345 May 6, 2025
86642a6
clean up
heheda12345 May 6, 2025
5b2bb9a
Merge branch 'main' of github.com:vllm-project/vllm into slot_mapping
heheda12345 May 9, 2025
b59d22c
revert num_query_heads
heheda12345 May 9, 2025
6eea844
clean up tpu code
heheda12345 May 9, 2025
e999b50
revert
heheda12345 May 9, 2025
112a772
small fix of mla
heheda12345 May 9, 2025
ec2efcc
small fix of mla
heheda12345 May 9, 2025
d54557a
clean up
heheda12345 May 9, 2025
652f3a9
remove empty line
heheda12345 May 9, 2025
997165b
page_size -> block_size
heheda12345 May 10, 2025
fbf05f5
fix tests
heheda12345 May 10, 2025
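
Taken together, the diffs below change every v1 attention metadata builder so it receives its KV-cache spec and block table at construction time instead of reading block size, head counts, and slot mappings off the runner. Here is a minimal sketch of the resulting constructor shape; the class and attribute names follow the flash-attention diff, while `AttentionSpecStub` and the bare `runner`/`block_table` arguments are illustrative stand-ins rather than vLLM's real classes:

```python
from dataclasses import dataclass

import torch


@dataclass
class AttentionSpecStub:
    """Stand-in for vllm.v1.kv_cache_interface.AttentionSpec."""
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool = False


class MetadataBuilderSketch:
    """Post-PR builder signature: the spec and block table are injected."""

    def __init__(self, runner, kv_cache_spec: AttentionSpecStub, block_table):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table
        # Block size now comes from the injected spec, not runner.block_size,
        # so a builder no longer assumes one global block size.
        self.block_size = kv_cache_spec.block_size


builder = MetadataBuilderSketch(
    runner=None,
    kv_cache_spec=AttentionSpecStub(block_size=16, num_kv_heads=8,
                                    head_size=128, dtype=torch.float16),
    block_table=None)
```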
3 changes: 3 additions & 0 deletions tests/v1/worker/test_gpu_input_batch.py
@@ -221,6 +221,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
@@ -310,6 +311,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
@@ -318,6 +320,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
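
The only change in this test file is the new `max_num_batched_tokens=1024` argument. It appears to follow from the slot mapping moving onto the block table (see the "move slot_mapping to block_table" commit): the table now owns per-token slot-mapping buffers, so it needs the token budget at construction time. A rough sketch of the layout this implies; the exact dtypes and buffer handling are assumptions, not copied from vLLM:

```python
import torch


class BlockTableSketch:
    """Illustrative layout: block IDs per request plus per-token slot mapping."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int,
                 max_num_batched_tokens: int, device: torch.device):
        # One row of physical block IDs per request.
        self.block_table = torch.zeros(max_num_reqs, max_num_blocks_per_req,
                                       dtype=torch.int32, device=device)
        # One KV-cache slot per scheduled token -- this is why the table now
        # needs max_num_batched_tokens at construction time.
        self.slot_mapping_cpu = torch.zeros(max_num_batched_tokens,
                                            dtype=torch.int64)
        self.slot_mapping = torch.zeros(max_num_batched_tokens,
                                        dtype=torch.int64, device=device)

    def get_device_tensor(self) -> torch.Tensor:
        return self.block_table


table = BlockTableSketch(max_num_reqs=32, max_num_blocks_per_req=10,
                         max_num_batched_tokens=1024,
                         device=torch.device("cpu"))
```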
21 changes: 20 additions & 1 deletion tests/v1/worker/test_gpu_model_runner.py
@@ -1,14 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
import weakref

import pytest
import torch

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)


@pytest.fixture
def model_runner():
scheduler_config = SchedulerConfig(
@@ -38,7 +55,9 @@ def model_runner():
)

device = "cuda"
return GPUModelRunner(vllm_config, device)
runner = GPUModelRunner(vllm_config, device)
initialize_kv_cache(runner)
return runner


def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
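
The `initialize_kv_cache` helper above mirrors the wiring the runner itself performs: ask the attention backend for its builder class and hand it the spec plus the input batch's block table. Below is a hedged usage sketch with the same call shape; `FakeBackend`, `FakeBuilder`, and `FakeRunner` are invented for illustration, and the comment on `weakref.proxy` states the presumed rationale (avoiding a strong builder-to-runner back-reference) rather than a documented one:

```python
import weakref

import torch


class FakeBuilder:
    """Stands in for a backend-specific attention metadata builder."""

    def __init__(self, runner, kv_cache_spec, block_table):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table


class FakeBackend:
    """Stands in for an attention backend class."""

    @staticmethod
    def get_builder_cls():
        return FakeBuilder


class FakeRunner:
    """Just enough surface for the wiring shown in the test helper."""

    def __init__(self):
        self.attn_backend = FakeBackend()
        self.block_table = object()  # stands in for InputBatch.block_table
        self.attn_metadata_builder = None


def initialize_kv_cache_sketch(runner: FakeRunner, kv_cache_spec) -> None:
    # Same call shape as the test helper: builder_cls(proxy(runner), spec, table).
    # The proxy keeps the builder from holding a strong reference back to the
    # runner that already owns it.
    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
        weakref.proxy(runner), kv_cache_spec, runner.block_table)


spec = dict(block_size=16, num_kv_heads=1, head_size=64, dtype=torch.float16)
runner = FakeRunner()
initialize_kv_cache_sketch(runner, spec)
assert isinstance(runner.attn_metadata_builder, FakeBuilder)
```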
47 changes: 30 additions & 17 deletions vllm/v1/attention/backends/flash_attn.py
@@ -19,6 +19,8 @@
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -167,7 +169,7 @@ def make_local_attention_virtual_batches(
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.Tensor,
page_size: int = 0,
block_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
@@ -238,14 +240,14 @@ def make_local_attention_virtual_batches(
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts = k_seqstarts_absolute // page_size
assert attn_chunk_size % page_size == 0, \
block_starts = k_seqstarts_absolute // block_size
assert attn_chunk_size % block_size == 0, \
f"attn_chunk_size {attn_chunk_size} is not " \
f"divisible by page_size {page_size}"
pages_per_local_batch = attn_chunk_size // page_size
f"divisible by block_size {block_size}"
pages_per_local_batch = attn_chunk_size // block_size

# Create a block_table for the local attention blocks
# For out example if we have a block-table like (assuming page_size=2):
# For out example if we have a block-table like (assuming block_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
@@ -289,7 +291,8 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder:

def __init__(self, runner: "GPUModelRunner"):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config
compilation_config = runner.vllm_config.compilation_config

@@ -299,7 +302,9 @@ def __init__(self, runner: "GPUModelRunner"):
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

if get_flash_attn_version() == 3:
self.aot_schedule = not compilation_config.full_cuda_graph
@@ -323,9 +328,17 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]

block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)

slot_mapping = block_table.slot_mapping[:num_actual_tokens]

if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
@@ -354,7 +367,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
num_heads_q=self.num_heads_q,
num_heads_kv=self.num_heads_kv,
headdim=self.headdim,
page_size=self.page_size,
page_size=self.block_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
window_size=self.aot_sliding_window,
@@ -365,12 +378,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table = make_local_attention_virtual_batches(
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
block_table,
self.runner.block_size,
block_table_tensor,
self.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
@@ -389,7 +402,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
@@ -440,7 +453,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
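
The new `build()` path above copies the slot mapping out of the block table's CPU buffer and pads the unused tail with -1, so that `reshape_and_cache` sees a harmless sentinel when a full CUDA graph replays over the whole buffer. A standalone sketch of just that copy-and-pad step; the buffer names follow the diff, while the sizes and values are made up:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

max_num_batched_tokens = 16
num_actual_tokens = 5

# Persistent buffers owned by the block table (sizes here are arbitrary).
slot_mapping_cpu = torch.arange(max_num_batched_tokens, dtype=torch.int64)
slot_mapping = torch.empty(max_num_batched_tokens, dtype=torch.int64,
                           device=device)

# Copy only the slots for tokens actually scheduled this step.
slot_mapping[:num_actual_tokens].copy_(slot_mapping_cpu[:num_actual_tokens],
                                       non_blocking=True)
# Pad the unused tail with -1 so a full-CUDA-graph replay that touches the
# whole buffer writes to a recognizable "no slot" value instead of stale data.
slot_mapping[num_actual_tokens:].fill_(-1)

metadata_slot_mapping = slot_mapping[:num_actual_tokens]
```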
35 changes: 20 additions & 15 deletions vllm/v1/attention/backends/flashinfer.py
@@ -19,6 +19,8 @@
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -202,7 +204,8 @@ def __post_init__(self):

class FlashInferMetadataBuilder:

def __init__(self, runner: GPUModelRunner):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
@@ -213,6 +216,8 @@ def __init__(self, runner: GPUModelRunner):
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = get_current_vllm_config()
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
@@ -400,13 +405,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.runner.block_size
page_size = self.kv_cache_spec.block_size
device = self.runner.device
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -422,24 +426,25 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
dtype=torch.int32,
device=device)
shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
shared_kv_page_indices = block_table_tensor[
0, :num_common_kv_blocks]
shared_kv_last_page_len = torch.tensor([page_size],
dtype=torch.int32,
device=device)
# Remove the blocks of the shared prefix from all requests.
block_table = block_table[:, num_common_kv_blocks:]
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds -= num_common_kv_blocks
else:
shared_qo_indptr = None
shared_kv_page_indptr = None
shared_kv_page_indices = None
shared_kv_last_page_len = None

mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
device=block_table.device).unsqueeze(0)
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]
paged_kv_indices = block_table_tensor[mask]

paged_kv_indptr = torch.cat([
torch.zeros(1,
@@ -459,10 +464,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.runner.num_query_heads,
num_kv_heads=self.runner.num_kv_heads,
head_dim=self.runner.head_size,
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
page_size=page_size,
data_type=self.runner.kv_cache_dtype,
data_type=self.kv_cache_spec.dtype,
q_data_type=self.runner.dtype,
slot_mapping=slot_mapping,
num_decodes=self._num_decodes,
@@ -481,7 +486,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
return attn_metadata

def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.runner.kv_cache_dtype != self.runner.model_config.dtype:
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype.
return False
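
In the FlashInfer builder above, the per-layer geometry (page/block size, KV head count, head size, KV-cache dtype) now comes from the injected KVCacheSpec rather than from runner-wide attributes. A small sketch of that substitution; the field names follow the diff, but `DecodePlanArgs` is a reduced stand-in, not FlashInferMetadata:

```python
from dataclasses import dataclass

import torch


@dataclass
class KVCacheSpecStub:
    """Stand-in for the injected KV-cache spec."""
    block_size: int
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype


@dataclass
class DecodePlanArgs:
    """Reduced stand-in for the per-layer fields the metadata carries."""
    num_kv_heads: int
    head_dim: int
    page_size: int
    data_type: torch.dtype


def plan_args_from_spec(spec: KVCacheSpecStub) -> DecodePlanArgs:
    # Before: runner.num_kv_heads / runner.head_size / runner.block_size /
    # runner.kv_cache_dtype. After: the same values, read off the spec.
    return DecodePlanArgs(num_kv_heads=spec.num_kv_heads,
                          head_dim=spec.head_size,
                          page_size=spec.block_size,
                          data_type=spec.dtype)


args = plan_args_from_spec(KVCacheSpecStub(block_size=16, num_kv_heads=8,
                                           head_size=128, dtype=torch.float16))
```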
23 changes: 15 additions & 8 deletions vllm/v1/attention/backends/mla/common.py
@@ -207,6 +207,8 @@
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -334,6 +336,8 @@ class MLACommonMetadataBuilder(Generic[M]):

def __init__(self,
runner: "GPUModelRunner",
kv_cache_spec: AttentionSpec,
block_table: BlockTable,
metadata_cls: Optional[type[M]] = None):
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
@@ -346,10 +350,11 @@ def __init__(self,
runner.parallel_config)
self.mla_dims = get_mla_dims(model_config)
self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3)
self.kv_cache_spec = kv_cache_spec

# Dont try to access the runner on AMD
if self.aot_schedule:
self.page_size = self.runner.block_size
self.page_size = self.kv_cache_spec.block_size

if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
Expand All @@ -375,6 +380,7 @@ def __init__(self,
dtype=model_config.dtype,
device=runner.device,
)
self.block_table = block_table

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
@@ -436,9 +442,10 @@ def reorder_batch(self, input_batch: "InputBatch",

return modified_batch

def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor):
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_table=block_table,
block_table=block_table_tensor,
seq_lens=seq_lens,
)

@@ -451,9 +458,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.runner.device
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()

query_start_loc = common_attn_metadata.query_start_loc
@@ -530,7 +537,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
self.chunked_prefill_workspace_size

prefill_metadata = MLACommonPrefillMetadata(
block_table=block_table[reqs_start:, ...],
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
@@ -539,7 +546,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
decode_metadata = None
if self._num_decodes > 0:
decode_metadata = self._build_decode(
block_table=block_table[:self._num_decodes, ...],
block_table_tensor=block_table_tensor[:self._num_decodes, ...],
seq_lens=seq_lens[:self._num_decodes],
)

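
The MLA builder keeps the same decode/prefill split but now slices the injected block-table tensor instead of reaching into the runner's input batch; after `reorder_batch()`, decode requests sit at the front of the batch and prefills at the back. A minimal sketch of that slicing; the shapes and the `num_decodes` value are illustrative:

```python
import torch

num_reqs, max_blocks = 6, 4
num_decodes = 2  # reorder_batch() places decode requests first

# Stand-in for block_table.get_device_tensor()[:num_reqs].
block_table_tensor = torch.arange(num_reqs * max_blocks,
                                  dtype=torch.int32).reshape(num_reqs,
                                                             max_blocks)

# Decode metadata sees only the leading decode rows...
decode_block_table = block_table_tensor[:num_decodes, ...]
# ...while prefill metadata sees the remaining rows.
prefill_block_table = block_table_tensor[num_decodes:, ...]

assert decode_block_table.shape[0] + prefill_block_table.shape[0] == num_reqs
```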