[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders #17483


Merged
merged 31 commits on May 10, 2025

Conversation

heheda12345
Collaborator

@heheda12345 heheda12345 commented Apr 30, 2025

Should merge after #17394

The hybrid allocator will need to build attention metadata for each KV cache group, because different KV cache groups may have different attention types and block tables. To achieve that, we will introduce one AttentionMetadataBuilder and one BlockTable for each group.

To prepare for this, this PR makes AttentionMetadataBuilder access its own block_table and KVCacheSpec instead of reading them from model_runner.

And as slot_mapping will also differ between KV cache groups, this PR moves the slot_mapping_cpu tensor from the runner to BlockTable.

Split from #16101
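The ownership change described above can be sketched as follows. This is a minimal, hedged illustration of the direction (one builder and one block table per KV cache group, with the builder holding its spec and table instead of reaching into the runner); the method names `append_row` and `compute_slot_mapping` and the simplified fields are illustrative, not vLLM's exact API.

```python
from dataclasses import dataclass


@dataclass
class KVCacheSpec:
    """Simplified stand-in for vLLM's KVCacheSpec."""
    block_size: int


class BlockTable:
    """Per-group block table; after this PR it also owns the group's
    slot_mapping buffer (moved here from the model runner)."""

    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.rows: list[list[int]] = []      # physical block ids per request
        self.slot_mapping: list[int] = []    # per-token slot ids

    def append_row(self, blocks: list[int]) -> None:
        self.rows.append(blocks)

    def compute_slot_mapping(self, token_positions: list[int], req: int) -> None:
        # slot = block_id * block_size + offset_within_block
        blocks = self.rows[req]
        self.slot_mapping = [
            blocks[p // self.block_size] * self.block_size + p % self.block_size
            for p in token_positions
        ]


class AttentionMetadataBuilder:
    """After this PR: the builder receives its own block_table and
    kv_cache_spec instead of reading them from model_runner."""

    def __init__(self, kv_cache_spec: KVCacheSpec, block_table: BlockTable) -> None:
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table


# One builder per KV cache group, each with its own table:
groups = [KVCacheSpec(block_size=16), KVCacheSpec(block_size=16)]
builders = [AttentionMetadataBuilder(s, BlockTable(s.block_size)) for s in groups]
```

Each group can then compute its own slot mapping independently, which is what makes differing block sizes or attention types per group workable.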

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Apr 30, 2025

mergify bot commented Apr 30, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 30, 2025
@mergify mergify bot removed the needs-rebase label May 1, 2025
Collaborator

@WoosukKwon WoosukKwon left a comment


LGTM!

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label May 9, 2025
@WoosukKwon
Collaborator

@heheda12345 Please fix the CI failure 😅

@WoosukKwon WoosukKwon merged commit 950751a into vllm-project:main May 10, 2025
34 of 35 checks passed
@tjtanaa
Contributor

tjtanaa commented May 11, 2025

@heheda12345 may I know if the pre-commit check caught this?

The rename of block_table to block_table_tensor in the _build_decode function had broken AiterMLAMetadataBuilder._build_decode.

If it didn't, some investigation may be needed to find out why the CI did not catch it.

vllm-bot pushed a commit that referenced this pull request May 11, 2025
…me in PR #17483 (#17961)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Yikun added a commit to vllm-project/vllm-ascend that referenced this pull request May 11, 2025
#806)

### What this PR does / why we need it?

1. Fix V1 error found by
[nightly_ci](https://github.com/vllm-project/vllm-ascend/actions/runs/14950004754/job/41998136610),
broken by [[v1] Pass BlockTable and KVCacheSpec to
AttentionMetadataBuilders
#17483](vllm-project/vllm#17483), make
`InputBatch` parameter consistent with vllm.
2. Disable benmark and fix it in upstream.

### Does this PR introduce _any_ user-facing change?

No


### How was this patch tested?

CI passed

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
juncgu added a commit to juncgu/vllm that referenced this pull request May 11, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…lm-project#17483)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…me in PR vllm-project#17483 (vllm-project#17961)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
@heheda12345
Collaborator Author

@tjtanaa An interesting question. I think pre-commit can pass because we have MLACommonMetadataBuilder._build_decode with the correct signature. Maybe we need stricter type annotations.
A small reproduction script:

class A:

    def func(self, a: int, b: int) -> None:
        pass

    def func2(self, a: int, b: int) -> None:
        self.func(a, b=b)


class B(A):

    def func(self, a: int, bb: int) -> None:
        pass


# B().func(1, b=2) # type error
B().func2(1, 2)  # pass type check but crash like you mentioned.
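One way to make such renames harmless, as a hedged sketch rather than a suggestion the PR itself adopts: declare the base method's parameters positional-only with PEP 570's `/` marker. Callers then cannot pass them by keyword, so a subclass renaming a parameter can no longer break indirect calls like `func2`'s.

```python
class A:
    # `a` and `b` are positional-only (PEP 570): callers may not pass
    # them by keyword, so overrides are free to rename them.
    def func(self, a: int, b: int, /) -> None:
        pass

    def func2(self, a: int, b: int) -> None:
        # Must call positionally; `self.func(a, b=b)` would now be a
        # type error at this call site.
        self.func(a, b)


class B(A):
    # Renamed parameter: safe, because it can never be reached by
    # keyword through the base-class interface.
    def func(self, a: int, bb: int, /) -> None:
        pass


B().func2(1, 2)  # no crash, unlike the keyword-call version above
```

This trades away keyword-call convenience at every call site, so it mainly suits internal interfaces like `_build_decode` where subclasses are expected to override with different parameter names.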

@chenyang78
Contributor

chenyang78 commented May 13, 2025

Looks like this commit caused a regression in the FlashInfer backend (with FlashInfer's latest commit, 25fb40), at least on GB200 and B200. With this commit, the following command failed with a CUDA out-of-memory error:

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
...
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.50 GiB. ..

It worked fine right before this commit.

@chenyang78
Contributor

The same CUDA out-of-memory failure also occurred on H100 with the same command as above:

$ VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto 
...
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.12 GiB. GPU 0 has a total capacity of 95.00 GiB of which 1.01 GiB is free.

BTW, I was using CUDA 12.8. Thanks!

mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…lm-project#17483)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…me in PR vllm-project#17483 (vllm-project#17961)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…lm-project#17483)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…me in PR vllm-project#17483 (vllm-project#17961)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Labels
ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1
Projects
None yet
Development

4 participants