Skip to content

Conversation

LucasWilkinson
Copy link
Collaborator

@LucasWilkinson LucasWilkinson commented Jul 25, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Support multiple attention metadata builders per kv-cache spec so we can undo the hacky fix in #21707

Test Plan

Use the same ruler task as in: #21707

Test Result

python -m lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=4,gpu_memory_utilization=0.8,trust_remote_code=True,max_model_len=32768,disable_hybrid_kv_cache_manager=True --tasks ruler_qa_squad --limit 100 --batch_size auto --metadata='{"max_seq_lengths":[16384]}'
...
|    Tasks     |Version|Filter|n-shot|Metric|   | Value |   |Stderr|
|--------------|------:|------|-----:|-----:|---|------:|---|------|
|ruler_qa_squad|      1|none  |     0| 16384|↑  | 0.7092|±  |   N/A|
|              |       |none  |     0|  4096|↑  |-1.0000|±  |   N/A|
lm_eval --model vllm --model_args '{"pretrained": "google/gemma-3n-E2B-it"}' --tasks gsm8k --batch_size auto
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6406|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6399|±  |0.0132|

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 25, 2025
@LucasWilkinson LucasWilkinson changed the title [WIP] local attention no hybrid kv cache [WIP] local attention no hybrid kv cache + support multiple attention metadata builders per kv_cache_spec Jul 25, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

An excellent and comprehensive refactoring effort! The introduction of AttentionGroup and the dynamic wrapping for local attention significantly improve the modularity and maintainability of the attention backend handling. This new approach is much cleaner for supporting heterogeneous attention mechanisms within a model.

I've identified one area for improvement to enhance the robustness of the new implementation. Please see my detailed comment below.

Once that's addressed, this looks like a solid contribution.

self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
for i, kv_cache_group_spec in enumerate(
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The previous implementation of this function included an assertion to ensure it's not called more than once (assert len(self.attn_backends) == 0 and len(self.attn_metadata_builders) == 0). This assertion has been removed in the refactoring.

To maintain robustness and prevent accidental re-initialization which could lead to a corrupted state, it would be good to add a similar assertion back, checking the new state attributes (self.attn_groups or self.attn_groups_dict).

assert not self.attn_groups, "Attention backends are already initialized"
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

Copy link
Collaborator

@sarckk sarckk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, this is looking great


for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
for layer_name in attn_group.layer_names:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need further changes to support cross-layer KV sharing. Previous to this PR, we add the KV-reusing layers to .layer_names of the KV cache group of the target layer, which ensures that attn_metadata is populated for these layers. With this PR, this depends on attn_group.layer_names which is populated before KV sharing logic is executed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you describe the changes needed? and the best model/command to test them with? that would be super helpful (in not that spun-up on kv-cache sharing)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, you can try it out with gemma3n:

vllm serve google/gemma-3n-E2B-it --disable-log-requests

or run the unit test:

pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"

But it looks like you've already handled this in your latest commits (assuming one attention group per KV cache group)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tested!

lm_eval --model vllm --model_args '{"pretrained": "google/gemma-3n-E2B-it"}' --tasks gsm8k --batch_size auto
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6406|±  |0.0132|
|     |       |strict-match    |     5|exact_match|↑  |0.6399|±  |0.0132|

#print(f"Building local attention metadata builder {type(self)}")
# TODO(lucas): this requires the attention metadata builder save the
# kv_cache_spec, as an attribute; we maybe can do something better here
common_attn_metadata = make_local_attention_virtual_batches(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should also think about how different transformations on the common attn metadata can compose with each other. e.g. for yoco we need to modify the metadata prior to make_local_attention_virtual_batches.

I guess we can stack these like: make_decode_only_metadata_builder(make_local_attention_metadata_builder(...))

@mergify mergify bot added the llama Related to Llama models label Jul 26, 2025
Copy link

mergify bot commented Jul 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/local-attention-no-hybrid-kv-cache branch from 57905cf to a5f9db2 Compare July 29, 2025 05:32
@mergify mergify bot removed the needs-rebase label Jul 29, 2025
Copy link

mergify bot commented Jul 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 29, 2025
assert len(attn_groups[group_idx]) == 1, (
"Only one attention group per KV cache group is supported "
"for KV-cache sharing for now.")
# TODO(lucas): I think in the future the layers that re-use a
Copy link
Collaborator

@sarckk sarckk Jul 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To flesh this out a bit more, I'm not sure layers re-using KV cache should always be placed in a separate attention group. Let's say we have the following 8-layer config:

L0: Full Attn
L1: Local Attn
L2: Reuse L0 (Full Attn)
L3: Reuse L1 (Local Attn)
L4: Full Attn
L5: Local Attn
L6: Reuse L4 (Full Attn)
L7: Reuse L5 (Local Attn)

Then without hybrid KV cache we should have attn_groups looking like:

[
  [
    AttentionGroup(FullAttnBackend, layers=[L0, L4]),
    AttentionGroup(LocalAttnBackend, layers=[L1, L5])
  ],
]

Where should L2, L3, L6, L7 be added? L6 and L7 qualify for faster prefill by virtue of being trailing KV-sharing layers, so they need a separate attention metadata builder (and thus a separate AttentionGroup). so in my mind, it would look like the following:

[
  [
    AttentionGroup(FullAttnBackend, layers=[L0, L4, L2]),
    AttentionGroup(LocalAttnBackend, layers=[L1, L5, L3]),
    AttentionGroup(**FasterPrefill**FullAttnBackend, layers=[L6]),
    AttentionGroup(**FasterPrefill**LocalAttnBackend, layers=[L7]),
  ],
]

Layers L2 and L3 have been placed in the respective AttentionGroup of its target layer. L6 and L7 are special cases (qualify for faster prefill), so they have been added as two separate AttentionGroups.

WDYT? cc: @heheda12345

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry im not totally spun up on this FasterPrefill optimization; why dont L2 and L3 qualify?

overall though this makes sense to me I think

Copy link
Collaborator

@sarckk sarckk Jul 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why dont L2 and L3 qualify?

L2, L3, L6 and L7 all use cross-attention to reuse the shared KV caches (let's refer to them as cross-attention layers). L0, L1, L4, L5 are self-attention layers.

For cross-attention layers, during decoding we only need the KV caches of the target layers to be populated for cross-attention. This means we can apply an optimization where given N prompt tokens for a given request, we can skip N-1 tokens during prefill as each cross-attention layer only depends on its target layer having done attention with full N tokens.

For L6, its target layer is L4. L4 will do prefill with full N tokens, so L6 can just work with the single last token. Same with L7. However, for L2 and L3, we have self-attention layers L4 and L5 that come after it, which require its KV caches to be populated with the full N prompt tokens (for L6 and L7 cross-attention). This means that L4 and L5 forward() must return the correct logits for full N tokens, so it cannot skip the first N-1 tokens.

Copy link
Collaborator Author

@LucasWilkinson LucasWilkinson Jul 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah makes sense; thanks for the detailed explanation!

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/local-attention-no-hybrid-kv-cache branch from a5f9db2 to 6036f68 Compare July 30, 2025 03:48
@LucasWilkinson LucasWilkinson marked this pull request as ready for review July 30, 2025 03:51
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/local-attention-no-hybrid-kv-cache branch from 91d1ca9 to 2a7975f Compare July 30, 2025 03:53
@mergify mergify bot removed the needs-rebase label Jul 30, 2025
@LucasWilkinson LucasWilkinson changed the title [WIP] local attention no hybrid kv cache + support multiple attention metadata builders per kv_cache_spec [Attention] local attention no hybrid kv cache + support multiple attention metadata builders per kv_cache_spec Jul 30, 2025
@LucasWilkinson LucasWilkinson changed the title [Attention] local attention no hybrid kv cache + support multiple attention metadata builders per kv_cache_spec [Attention] Support multiple attention metadata builders per kv_cache_spec + proper local attention no hybrid kv cache fix Jul 30, 2025
Copy link

mergify bot commented Jul 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@vllm-bot vllm-bot merged commit 1dc8a70 into vllm-project:main Aug 7, 2025
43 of 45 checks passed
nvjullin pushed a commit to nvjullin/vllm that referenced this pull request Aug 7, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
jingyu-ml pushed a commit to jingyu-ml/vllm that referenced this pull request Aug 8, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: jingyu <jingyu@omniml.ai>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Noam Gat <noamgat@gmail.com>
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 11, 2025
Summary:
vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec. 

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used. 

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py 
```

Rollback Plan:

Differential Revision: D80020191
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 11, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec. 

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used. 

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py 
```

Rollback Plan:

Differential Revision: D80020191
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 11, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec. 

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used. 

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py 
```

Rollback Plan:

Differential Revision: D80020191
yyihuang pushed a commit to yyihuang/vllm that referenced this pull request Aug 11, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Avery Yingyi Huang <yingyihuang2000@outlook.com>
wuhang2014 pushed a commit to wuhang2014/vllm that referenced this pull request Aug 12, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
aarnphm pushed a commit to aarnphm/vllm that referenced this pull request Aug 13, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 14, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec. 

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used. 

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py 
```

Rollback Plan:

Differential Revision: D80020191
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 14, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Boyuan Feng <boyuan@meta.com>
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 15, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec.

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used.

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py
```

Rollback Plan:

Differential Revision: D80020191

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 15, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec. 

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used. 

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py 
```

Rollback Plan:

Differential Revision: D80020191
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
sarckk added a commit to sarckk/vllm that referenced this pull request Aug 15, 2025
Summary:

vllm-project#21588 added support for multiple attention metadata builders per kv-cache spec.

As part of this change, each KV cache group now maps to one or more `AttentionGroup`, with one attention group being created for each type of attention backend used.

However, if we want to enable KV sharing when we have more than one attention group, we run into the following assertion:
```
            assert len(attn_groups[group_idx]) == 1, (
                "Only one attention group per KV cache group is supported "
                "for KV-cache sharing for now.")
```

This PR adds support to make this implementation more flexible, such that we can support KV cache sharing when there are multiple attention groups per KV cache group.

Test Plan:
new added unit test passes:
```
pytest tests/v1/test_kv_sharing.py
```

Rollback Plan:

Differential Revision: D80020191

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
juuice-lee pushed a commit to juuice-lee/vllm-moe.code that referenced this pull request Aug 18, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
dumb0002 pushed a commit to dumb0002/vllm that referenced this pull request Aug 28, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025
…_spec + proper local attention no hybrid kv cache fix (vllm-project#21588)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llama Related to Llama models ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding tpu Related to Google TPUs v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants