[v1] Hybrid Memory Allocator #17996

Merged
45 commits merged on Jun 6, 2025

Commits (45)
4f81b65
hybrid allocator
heheda12345 May 11, 2025
ec55021
refactor
heheda12345 May 12, 2025
41e027a
fix bug
heheda12345 May 15, 2025
0735539
minor updates
heheda12345 May 15, 2025
5e53840
minor updates
heheda12345 May 15, 2025
dcfe6ca
avoid frequent creation of block_bundle
heheda12345 May 16, 2025
b94ed65
update notes
heheda12345 May 16, 2025
deafbda
fix config
heheda12345 May 16, 2025
18798da
fix tests in v1/core
heheda12345 May 16, 2025
94bd895
mapping as clas attribute
heheda12345 May 20, 2025
e50e33c
simplify KVCacheBlocks
heheda12345 May 23, 2025
5c2887a
unify interface
heheda12345 May 23, 2025
32afa9d
add a tofix
heheda12345 May 24, 2025
c75c9b5
a runable version without prefix caching
heheda12345 May 30, 2025
5800dcf
support prefix caching
heheda12345 May 30, 2025
c96e13b
fix bug for --disable-hybrid-kv-cache-manager
heheda12345 May 30, 2025
d8ad1be
clean up
heheda12345 May 30, 2025
3751688
fix padding calculation
heheda12345 May 30, 2025
159f51c
clean up
heheda12345 May 30, 2025
84f27b9
update code inside vllm/core
heheda12345 Jun 1, 2025
9c04802
update kv cache init
heheda12345 Jun 1, 2025
72c2671
coordinator
heheda12345 Jun 1, 2025
05f3406
add some notes
heheda12345 Jun 2, 2025
3019f1b
add notes about assumptions
heheda12345 Jun 2, 2025
ca6a00b
simplify verify_and_split_kv_cache_groups
heheda12345 Jun 2, 2025
15b4449
update explaination
heheda12345 Jun 2, 2025
1a862f9
BlockHashType->BlockHash
heheda12345 Jun 2, 2025
904bd25
update coordinator
heheda12345 Jun 2, 2025
66032cf
small fix
heheda12345 Jun 2, 2025
395e2bc
small fix
heheda12345 Jun 2, 2025
3556db8
update logging
heheda12345 Jun 2, 2025
b416963
add todo in this pr
heheda12345 Jun 2, 2025
e629ee8
fix tpu backend
heheda12345 Jun 2, 2025
a52d271
pass tests in v1/core
heheda12345 Jun 3, 2025
08e0888
revert previous change in tests/v1/core
heheda12345 Jun 3, 2025
2140dc6
update worker test
heheda12345 Jun 3, 2025
b63d8ea
test_cache_blocks_multi_group
heheda12345 Jun 3, 2025
b64b8b1
test_prefill_hybrid_model
heheda12345 Jun 3, 2025
5fb5e49
revert test_scheduler
heheda12345 Jun 3, 2025
13b486a
revert test_manager
heheda12345 Jun 3, 2025
b598c0e
test_get_kv_cache_config
heheda12345 Jun 3, 2025
ee71bd8
Merge branch 'main' of github.com:vllm-project/vllm into hybrid_alloc…
heheda12345 Jun 4, 2025
85798c5
fix ci
heheda12345 Jun 4, 2025
b5fa8e1
fix kv connector tests
heheda12345 Jun 4, 2025
fa2f7bc
small updates
heheda12345 Jun 4, 2025
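
As context for the file diff below, the test changes reflect a data-model change in `KVCacheConfig`: the per-layer `tensors` dict is replaced by a `kv_cache_tensors` list, where each `KVCacheTensor` names the layers backed by its buffer via `shared_by`, so layers from different KV cache groups (e.g. full attention and sliding window) can share one allocation. The sketch below is a simplified, self-contained mock of that layout using plain dataclasses rather than vLLM's actual classes; the `*Mock` names are invented here purely for illustration.

```python
# Minimal mock of the tensor-sharing layout exercised by the updated tests
# (illustrative only; the real classes live in vllm.v1.kv_cache_interface).
from dataclasses import dataclass


@dataclass
class KVCacheTensorMock:
    size: int             # bytes reserved for this buffer
    shared_by: list[str]  # layer names backed by this buffer


@dataclass
class KVCacheGroupSpecMock:
    layer_names: list[str]  # layers managed together (same attention type)
    spec: str               # e.g. "full_attention" or "sliding_window"


@dataclass
class KVCacheConfigMock:
    num_blocks: int
    kv_cache_tensors: list[KVCacheTensorMock]
    kv_cache_groups: list[KVCacheGroupSpecMock]


# One full-attention layer and one sliding-window layer share a single buffer,
# mirroring the "full + sliding, with hybrid_kv_cache_manager" case in the test.
hybrid = KVCacheConfigMock(
    num_blocks=64,
    kv_cache_tensors=[
        KVCacheTensorMock(size=1 << 20, shared_by=["layer_1", "layer_2"]),
    ],
    kv_cache_groups=[
        KVCacheGroupSpecMock(["layer_1"], "full_attention"),
        KVCacheGroupSpecMock(["layer_2"], "sliding_window"),
    ],
)
print(len(hybrid.kv_cache_tensors))  # 1: a single buffer backs both layers
```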
tests/v1/core/test_kv_cache_utils.py: 245 changes (210 additions, 35 deletions)
@@ -15,8 +15,8 @@
from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics,
estimate_max_model_len, generate_block_hash_extra_keys,
get_max_concurrency_for_kv_cache_config, hash_block_tokens,
hash_request_tokens, unify_kv_cache_configs)
get_kv_cache_config, get_max_concurrency_for_kv_cache_config,
hash_block_tokens, hash_request_tokens, unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
SlidingWindowSpec)
@@ -63,6 +63,20 @@ def new_kv_cache_spec(block_size=16,
sliding_window=sliding_window)


def new_sliding_window_spec(block_size=16,
num_kv_heads=2,
head_size=64,
dtype=torch.float32,
use_mla=False,
sliding_window=1):
return SlidingWindowSpec(block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
use_mla=use_mla,
sliding_window=sliding_window)


def test_none_hash(monkeypatch):
import vllm.v1.core.kv_cache_utils

@@ -403,10 +417,10 @@ def test_unify_kv_cache_configs():
same_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@@ -415,10 +429,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@@ -433,10 +447,10 @@ def test_unify_kv_cache_configs():
need_sort_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@@ -445,10 +459,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer2"],
new_kv_cache_spec(num_kv_heads=4)),
@@ -464,10 +478,10 @@ def test_unify_kv_cache_configs():
diff_kv_cache_config = [
KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@@ -476,10 +490,10 @@ def test_unify_kv_cache_configs():
),
KVCacheConfig(
num_blocks=20,
tensors={
"layer1": KVCacheTensor(100),
"layer2": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
KVCacheTensor(size=100, shared_by=["layer2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer2"],
@@ -636,7 +650,7 @@ def test_get_max_concurrency_for_kv_cache_config():

kv_cache_config_full_attention = KVCacheConfig(
num_blocks=int(1024 * 1.5),
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
full_attention_spec),
@@ -648,7 +662,7 @@ def test_get_max_concurrency_for_kv_cache_config():

kv_cache_config_sliding_window = KVCacheConfig(
num_blocks=129 * 3,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
sliding_window_spec),
@@ -660,7 +674,7 @@ def test_get_max_concurrency_for_kv_cache_config():

kv_cache_config_hybrid_model = KVCacheConfig(
num_blocks=(1024 + 129) * 3,
tensors={},
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec([f"layer_{i}" for i in range(32)],
full_attention_spec),
@@ -678,9 +692,9 @@ def test_allocate_with_lookahead():
block_size = 4
config = KVCacheConfig(
num_blocks=10,
tensors={
"layer1": KVCacheTensor(100),
},
kv_cache_tensors=[
KVCacheTensor(size=100, shared_by=["layer1"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer1"],
new_kv_cache_spec(block_size=block_size)),
@@ -702,7 +716,7 @@ def test_allocate_with_lookahead():
num_new_tokens=3,
num_lookahead_tokens=2, # Total required: 3+2=5 tokens
)
assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks
assert len(blocks.get_block_ids()[0]) == 2 # ceil(5/4)=2 blocks

# Test case 2: With precomputed blocks
kv_cache_manager = KVCacheManager(kv_cache_config=config,
@@ -713,7 +727,7 @@ def test_allocate_with_lookahead():
num_new_tokens=3,
num_lookahead_tokens=2,
)
assert len(blocks.blocks) == 2
assert len(blocks.get_block_ids()[0]) == 2

# Test case 3: With precomputed blocks
# required_blocks = ceil((3 + 4) / 4) = 2
@@ -724,4 +738,165 @@
num_new_tokens=3,
num_lookahead_tokens=4,
)
assert len(blocks.blocks) == 2
assert len(blocks.get_block_ids()[0]) == 2


def test_get_kv_cache_config():
# pass max_model_len to pass check_enough_kv_cache_memory
model_config = ModelConfig(max_model_len=16)
vllm_config = VllmConfig(model_config=model_config)

mem_per_block_per_layer = 16 * 2 * 64 * 4 * 2
# all layers are full attention -> single group
kv_cache_specs_full = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
}
kv_cache_config_full = get_kv_cache_config(
vllm_config, kv_cache_specs_full, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_full == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec())
])

# all layers are sliding window -> single group
kv_cache_specs_sliding = {
'layer_1': new_sliding_window_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_sliding = get_kv_cache_config(
vllm_config, kv_cache_specs_sliding, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_sliding == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_sliding_window_spec())
])

# full + sliding, but disable_hybrid_kv_cache_manager
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = True
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"],
new_kv_cache_spec(sliding_window=1)),
],
)
vllm_config.scheduler_config.disable_hybrid_kv_cache_manager = False

# full + sliding, with hybrid_kv_cache_manager
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=64,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 64,
shared_by=["layer_1", "layer_2"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_2"], new_sliding_window_spec()),
],
)

# 2 full + 4 sliding, 2 layers per group
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
'layer_3': new_sliding_window_spec(),
'layer_4': new_sliding_window_spec(),
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 2 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_3", "layer_5"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_4", "layer_6"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()),
KVCacheGroupSpec(["layer_3", "layer_4"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_5", "layer_6"],
new_sliding_window_spec()),
],
)

# 3 full + 7 sliding, pad to 3 full + 9 sliding
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(),
'layer_2': new_kv_cache_spec(),
'layer_3': new_kv_cache_spec(),
'layer_4': new_sliding_window_spec(),
'layer_5': new_sliding_window_spec(),
'layer_6': new_sliding_window_spec(),
'layer_7': new_sliding_window_spec(),
'layer_8': new_sliding_window_spec(),
'layer_9': new_sliding_window_spec(),
'layer_10': new_sliding_window_spec(),
}
kv_cache_config_hybrid = get_kv_cache_config(
vllm_config, kv_cache_specs_hybrid, mem_per_block_per_layer * 3 * 32)
assert kv_cache_config_hybrid == KVCacheConfig(
num_blocks=32,
kv_cache_tensors=[
KVCacheTensor(
size=mem_per_block_per_layer * 32,
shared_by=["layer_1", "layer_4", "layer_7", "layer_10"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_2", "layer_5", "layer_8"]),
KVCacheTensor(size=mem_per_block_per_layer * 32,
shared_by=["layer_3", "layer_6", "layer_9"]),
],
kv_cache_groups=[
KVCacheGroupSpec(["layer_1", "layer_2", "layer_3"],
new_kv_cache_spec()),
KVCacheGroupSpec(["layer_4", "layer_5", "layer_6"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_7", "layer_8", "layer_9"],
new_sliding_window_spec()),
KVCacheGroupSpec(["layer_10"], new_sliding_window_spec()),
],
)

# different hidden size, unimplemented
kv_cache_specs_hybrid = {
'layer_1': new_kv_cache_spec(head_size=128),
'layer_2': new_kv_cache_spec(),
}
with pytest.raises(NotImplementedError):
get_kv_cache_config(vllm_config, kv_cache_specs_hybrid,
mem_per_block_per_layer * 2 * 32)
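
As a reading aid for the constants asserted in `test_get_kv_cache_config` above, the snippet below reproduces the sizing arithmetic under the assumptions the test appears to use (float32 KV entries at 4 bytes each and separate K and V planes); it is an illustrative check, not code from the PR.

```python
# Back-of-the-envelope sizing check for the numbers asserted in
# test_get_kv_cache_config (assumes 4-byte float32 entries, K plane + V plane).
block_size = 16    # tokens per KV cache block
num_kv_heads = 2
head_size = 64
dtype_bytes = 4    # torch.float32
kv_planes = 2      # one plane for K, one for V

mem_per_block_per_layer = (block_size * num_kv_heads * head_size
                           * dtype_bytes * kv_planes)
assert mem_per_block_per_layer == 16 * 2 * 64 * 4 * 2  # 16384 bytes

# Budget used by the two-layer cases: room for 32 blocks in each of 2 layers.
available = mem_per_block_per_layer * 2 * 32

# All-full-attention (or all-sliding-window): each layer gets its own tensor,
# so the budget is split across two per-layer tensors -> 32 blocks.
assert available // (2 * mem_per_block_per_layer) == 32

# Hybrid manager: the full-attention layer and the sliding-window layer share
# one tensor, so the whole budget backs a single per-layer buffer -> 64 blocks.
assert available // mem_per_block_per_layer == 64
```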