[v1] Implement HybridKVCacheManager to support hybrid models with different KV cache type #16101

Closed · wants to merge 39 commits
Commits (39) · changes shown from all commits
c6a2d25
copy manager code
heheda12345 Apr 1, 2025
4b27c82
save
heheda12345 Apr 5, 2025
4dea38d
can run
heheda12345 Apr 5, 2025
55720e0
can pass e2e tests
heheda12345 Apr 5, 2025
273dd44
run precommit
heheda12345 Apr 5, 2025
0bfec8d
can run again
heheda12345 Apr 5, 2025
df31d7a
quick copy
heheda12345 Apr 6, 2025
1ce3023
Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem
heheda12345 Apr 23, 2025
6aee98d
a runable version
heheda12345 Apr 23, 2025
7f19466
fix bug
heheda12345 Apr 23, 2025
34ba571
1 hash per block_size
heheda12345 Apr 23, 2025
18245e3
one manager for each type
heheda12345 Apr 24, 2025
2c81fe6
small update
heheda12345 Apr 24, 2025
42a8244
small fix
heheda12345 Apr 24, 2025
8af9ace
Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem
heheda12345 Apr 25, 2025
6493e5e
fix gemma
heheda12345 Apr 25, 2025
fa224f2
Merge branch 'fix_gemma' of github.com:heheda12345/vllm into hybrid_mem
heheda12345 Apr 25, 2025
c512bc5
update attn backends
heheda12345 Apr 25, 2025
4ce3424
fix flashinfer type
heheda12345 Apr 25, 2025
d17843e
fix flashmla type
heheda12345 Apr 25, 2025
47ec1a7
fix triton type
heheda12345 Apr 25, 2025
840675f
clean up slidingwindowspec
heheda12345 Apr 25, 2025
ffcbde8
clean up block table
heheda12345 Apr 25, 2025
4eebce7
clean up runner (WIP)
heheda12345 Apr 25, 2025
4380fa6
add notes
heheda12345 Apr 25, 2025
b567c56
clean up attn_metadata read in runner
heheda12345 Apr 25, 2025
710d68e
reorder attn args
heheda12345 Apr 25, 2025
2b8ffc4
rename max_num_tokens
heheda12345 Apr 25, 2025
b50aa14
remove fixed TODO
heheda12345 Apr 25, 2025
136a54c
fix
heheda12345 Apr 25, 2025
765d9ed
add note
heheda12345 Apr 25, 2025
216a079
group partition strategy
heheda12345 Apr 26, 2025
1c66541
Merge branch 'main' of github.com:vllm-project/vllm into hybrid_mem
heheda12345 Apr 26, 2025
37c4494
support eagle
heheda12345 Apr 26, 2025
84280fc
get a specific type of layer from forward context
heheda12345 Apr 26, 2025
51ffeb6
fix
heheda12345 Apr 26, 2025
7d03c88
Merge branch 'filter_fwd_ctx' of github.com:heheda12345/vllm into hyb…
heheda12345 Apr 26, 2025
1da28d9
update eagle
heheda12345 Apr 26, 2025
e5cb02e
only enable cuda platform
heheda12345 Apr 26, 2025
Files changed
3 changes: 2 additions & 1 deletion examples/offline_inference/basic/basic.py
@@ -15,7 +15,8 @@

def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="google/gemma-3-1b-it", enforce_eager=True)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
6 changes: 3 additions & 3 deletions tests/v1/e2e/test_correctness_sliding_window.py
@@ -17,15 +17,15 @@ class TestConfig:

model_config = {
"bigcode/starcoder2-3b": TestConfig(4096, (800, 1100)),
"google/gemma-2-2b-it": TestConfig(4096, (400, 800)),
"google/gemma-3-1b-it": TestConfig(4096, (400, 800)), # TODO: swa 1024
}


@pytest.mark.parametrize(
"model",
[
"bigcode/starcoder2-3b", # sliding window only
"google/gemma-2-2b-it", # sliding window + full attention
# "bigcode/starcoder2-3b", # sliding window only
"google/gemma-3-1b-it", # sliding window + full attention
])
@pytest.mark.parametrize("batch_size", [5])
@pytest.mark.parametrize("seed", [1])
6 changes: 2 additions & 4 deletions vllm/attention/backends/flashinfer.py
@@ -37,7 +37,7 @@
is_block_tables_empty)
from vllm.attention.layer import Attention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
make_tensor_with_pad)
@@ -128,12 +128,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""

layers = vllm_config.compilation_config.static_forward_context
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: Dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
assert isinstance(layer, Attention)

impl = layer.impl
assert isinstance(impl, FlashInferImpl)

8 changes: 8 additions & 0 deletions vllm/attention/layer.py
@@ -210,6 +210,8 @@ def forward(
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
@@ -226,6 +228,8 @@ def forward(
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
@@ -374,6 +378,8 @@ def unified_attention(

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
output = self.impl.forward(self, query, key, value, kv_cache,
@@ -411,6 +417,8 @@ def unified_attention_with_output(
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
15 changes: 15 additions & 0 deletions vllm/config.py
@@ -1320,6 +1320,8 @@ class CacheConfig:
"""The number of blocks to allocate for GPU memory."""
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""
disable_hybrid_allocator: bool = False
"""Whether to disable the hybrid allocator (Only affects v1)."""

def compute_hash(self) -> str:
"""
@@ -4075,3 +4077,16 @@ def assert_hashable(text):
f"vLLM tried to hash some configs that may have Python objects ids "
f"in them. This is a bug, please file an issue. "
f"Text being hashed: {text}")


T = TypeVar("T")


def get_layers_from_vllm_config(vllm_config: VllmConfig,
layer_type: type[T]) -> dict[str, T]:
return {
layer_name: layer
for layer_name, layer in
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}
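
For context, a minimal usage sketch of the new helper (assuming an already-constructed `vllm_config` is in scope); it replaces the earlier pattern of iterating `static_forward_context` and asserting `isinstance`, as in the flashinfer changes above:

```python
from vllm.attention.layer import Attention
from vllm.config import get_layers_from_vllm_config

# Collect only the Attention layers registered in the static forward context,
# keyed by layer name; layers of other types are filtered out.
attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
for layer_name, layer in attn_layers.items():
    print(layer_name, type(layer.impl).__name__)
```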
8 changes: 8 additions & 0 deletions vllm/engine/arg_utils.py
@@ -254,6 +254,7 @@ class EngineArgs:
model_impl: str = "auto"

calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
disable_hybrid_allocator: bool = False

additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
@@ -948,6 +949,12 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")

parser.add_argument(
"--disable-hybrid-allocator",
action="store_true",
default=False,
help="Disable the hybrid allocator. This only affects v1.")

parser.add_argument(
"--additional-config",
type=json.loads,
@@ -1148,6 +1155,7 @@ def create_engine_config(
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
disable_hybrid_allocator=self.disable_hybrid_allocator,
)

# Get the current placement group if Ray is initialized and
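
A hedged usage sketch for the new flag: since `EngineArgs` now carries `disable_hybrid_allocator` and forwards it to `CacheConfig`, it should be reachable from the CLI (`--disable-hybrid-allocator`) and from the Python entrypoint, assuming `LLM` keyword arguments are forwarded to `EngineArgs` as for other engine flags:

```python
from vllm import LLM

# Fall back to the non-hybrid allocator (v1 only); the model choice is arbitrary.
llm = LLM(model="google/gemma-3-1b-it",
          enforce_eager=True,
          disable_hybrid_allocator=True)
```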
11 changes: 8 additions & 3 deletions vllm/forward_context.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

import torch
import torch.distributed as dist
@@ -38,8 +38,13 @@ class DPMetadata:
class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any]
# TODO: extend to support per-layer dynamic forward context
attn_metadata: "AttentionMetadata" # set dynamically for each forward pass
"""
Type AttentionMetadata for v0,
Type Dict[str, AttentionMetadata] for v1, mapping from layer_name to
AttentionMetadata of that layer
set dynamically for each forward pass
"""
attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
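
With `attn_metadata` now possibly a per-layer dict, each consumer looks its metadata up by layer name; this is the same access pattern added to `vllm/attention/layer.py` above, condensed here as a sketch (`layer_name` is assumed to be in scope):

```python
from vllm.forward_context import get_forward_context

# v1 passes a dict keyed by layer_name; v0 still passes a single object.
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
    attn_metadata = attn_metadata[layer_name]
```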
36 changes: 20 additions & 16 deletions vllm/v1/attention/backends/flash_attn.py
@@ -14,6 +14,9 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)

@@ -278,7 +281,8 @@ def make_local_attention_virtual_batches(

class FlashAttentionMetadataBuilder:

def __init__(self, runner: "GPUModelRunner"):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
block_table: BlockTable):
model_config = runner.model_config

self.runner = runner
@@ -288,23 +292,23 @@ def __init__(self, runner: "GPUModelRunner"):
self.num_heads_kv = model_config.get_num_kv_heads(
runner.parallel_config)
self.headdim = model_config.get_head_size()
self.page_size = self.runner.block_size
self.page_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
@@ -328,12 +332,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
local_attn_metadata = None
if self.runner.attention_chunk_size is not None:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table = make_local_attention_virtual_batches(
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
block_table,
self.runner.block_size,
block_table_tensor,
self.kv_cache_spec.block_size,
)
local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
self.runner.device, non_blocking=True)
@@ -352,7 +356,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=local_query_start_loc,
local_seqused_k=local_seqused_k,
local_block_table=virt_block_table,
local_block_table=virt_block_table_tensor,
local_max_query_len=local_max_query_len,
local_max_seq_len=local_max_seq_len,
local_scheduler_metadata=local_scheduler_metadata,
@@ -403,7 +407,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
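
Because the builder now receives the group's KV cache spec and block table explicitly instead of reading them from the runner, the runner can keep one builder per KV cache group. A hypothetical construction sketch (`runner` and `kv_cache_groups` are illustrative names, not part of this diff):

```python
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadataBuilder

# One metadata builder per KV cache group, each bound to that group's spec
# (block size, etc.) and its own block table.
builders = [
    FlashAttentionMetadataBuilder(runner, group.kv_cache_spec,
                                  group.block_table)
    for group in kv_cache_groups
]
```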
51 changes: 27 additions & 24 deletions vllm/v1/attention/backends/flashinfer.py
@@ -14,9 +14,12 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -81,12 +84,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""

layers = vllm_config.compilation_config.static_forward_context
layers = get_layers_from_vllm_config(vllm_config, Attention)
per_layer_params: dict[str, PerLayerParameters] = {}

for key, layer in layers.items():
assert isinstance(layer, Attention)

impl = layer.impl
assert isinstance(impl, FlashInferImpl)

@@ -205,7 +206,8 @@ def __post_init__(self):

class FlashInferMetadataBuilder:

def __init__(self, runner: GPUModelRunner):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
@@ -215,7 +217,9 @@ def __init__(self, runner: GPUModelRunner):
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None

self.vllm_config = get_current_vllm_config()
self.vllm_config = runner.vllm_config
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
@@ -398,19 +402,17 @@ def _plan(self, attn_metadata: FlashInferMetadata):
)

def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.runner.block_size
page_size = self.kv_cache_spec.block_size
device = self.runner.device
qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
qo_indptr = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = (self.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()

block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -426,24 +428,25 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
dtype=torch.int32,
device=device)
shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
shared_kv_page_indices = block_table_tensor[
0, :num_common_kv_blocks]
shared_kv_last_page_len = torch.tensor([page_size],
dtype=torch.int32,
device=device)
# Remove the blocks of the shared prefix from all requests.
block_table = block_table[:, num_common_kv_blocks:]
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
block_table_bounds -= num_common_kv_blocks
else:
shared_qo_indptr = None
shared_kv_page_indptr = None
shared_kv_page_indices = None
shared_kv_last_page_len = None

mask = (torch.arange(block_table.size(1),
dtype=block_table.dtype,
device=block_table.device).unsqueeze(0)
mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table_tensor.dtype,
device=block_table_tensor.device).unsqueeze(0)
< block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]
paged_kv_indices = block_table_tensor[mask]

paged_kv_indptr = torch.cat([
torch.zeros(1,
@@ -462,9 +465,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.runner.num_query_heads,
num_kv_heads=self.runner.num_kv_heads,
head_dim=self.runner.head_size,
num_qo_heads=self.kv_cache_spec.num_query_heads,
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
page_size=page_size,
data_type=self.runner.kv_cache_dtype,
q_data_type=self.runner.dtype,
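
For reference, the mask-and-flatten step in `build()` above converts a padded per-request block table into the flat `paged_kv_indices` layout FlashInfer expects; a small self-contained illustration with made-up numbers:

```python
import torch

# Two requests with 3 and 2 blocks in use, padded to width 4.
block_table_tensor = torch.tensor([[10, 11, 12, 0],
                                   [20, 21, 0, 0]], dtype=torch.int32)
block_table_bounds = torch.tensor([3, 2])

mask = (torch.arange(block_table_tensor.size(1),
                     dtype=block_table_tensor.dtype).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
# tensor([10, 11, 12, 20, 21], dtype=torch.int32)
```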