Conversation

@ayushsatyam146 (Contributor) commented Aug 24, 2025

Purpose

Solves #23130. This change fixes a critical bug in vLLM's cascade attention optimization in the V1 architecture. The bug is in get_num_common_prefix_blocks(), which determines how many KV cache blocks are shared among all currently running requests so that cascade attention can be enabled. A rough sketch of the fix is shown after the change list below.

Changes made

  • Replace ref_cnt-based common prefix detection with running request tracking
  • Update get_num_common_prefix_blocks() to accept running_request_ids set
  • Fix FullAttentionManager to count actual references from running requests
  • Prevent incorrect cascade attention when async KV offloading delays cleanup
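
As a rough illustration only (not the merged vLLM code; `Block`, `req_to_blocks`, `running_request_ids`, and the function name here are simplified stand-ins), the shift is from trusting a block's raw `ref_cnt` to checking, per prefix position, that every currently running request references the same block:

```python
from dataclasses import dataclass

@dataclass
class Block:
    """Hypothetical stand-in for a KV cache block, not the real vLLM class."""
    block_id: int
    ref_cnt: int = 0

def num_common_prefix_blocks(
    req_to_blocks: dict[str, list[Block]],
    running_request_ids: set[str],
) -> int:
    """Count the leading blocks shared by every *running* request.

    A finished request whose async KV offload is still in flight can keep a
    block's ref_cnt inflated, so the count is derived from the running
    requests' block lists instead of from ref_cnt alone.
    """
    if not running_request_ids:
        return 0
    block_lists = [req_to_blocks[req_id] for req_id in running_request_ids]
    num_common = 0
    # zip() stops at the shortest request, which bounds the common prefix.
    for blocks_at_pos in zip(*block_lists):
        first = blocks_at_pos[0]
        if all(b is first for b in blocks_at_pos):
            num_common += 1
        else:
            break
    return num_common
```

For example, with two running requests where only one of them maps a given block, plus a finished request whose offload is still pending and also maps that block, the block's ref_cnt equals 2 (the number of running requests) even though the block is not actually common; that is exactly the situation a ref_cnt-only test misreads. The walk above stops at the first position where the running requests diverge, so the cost is proportional to the number of running requests times the length of the shared prefix.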

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly addresses a critical bug in cascade attention related to asynchronous KV transfer by replacing the unreliable ref_cnt-based logic with explicit tracking of running requests. The changes are well-contained and logically sound. My review includes one suggestion to optimize the performance of the new common prefix block calculation, which could be a bottleneck in scenarios with many concurrent requests.

@heheda12345 heheda12345 (Collaborator) left a comment

Even after `block in self.req_to_blocks[req_id]` is fixed, I'm still concerned about the performance when all requests share a very long prefix. The time complexity is num_requests × num_blocks_per_request. What about passing in the requests that are not running but are still in KV transfer?

@ayushsatyam146 ayushsatyam146 force-pushed the kv-cache-fix branch 4 times, most recently from 542d108 to 8adac43 on August 29, 2025 04:48
@ayushsatyam146 (Contributor, Author)

Hi @heheda12345 @njhill The time complexity of the new code is now O(R×B), which was O(R×B²) in the previous iteration. I also have a caching-based implementation in mind that would bring the complexity down to O(1) in the best case and O(R×B) in the worst case, but it makes the code for this module a little more complex, so I did not want to push that version without approval. PTAL and let me know whether this is fine or whether we need to improve it further. Thanks!

@heheda12345 (Collaborator)

My example code is O((num_transferring_requests + 1) × num_common_blocks). It should be much faster than num_running_requests × num_common_blocks for short requests.
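
A rough reconstruction of what such example code could look like (hypothetical, reusing the illustrative `Block`/`req_to_blocks` stand-ins from the sketch in the PR description rather than the real vLLM structures): walk the prefix blocks of a single running request and keep the `ref_cnt` test, but first discount references held by requests that have finished running and are only alive because of an in-flight KV transfer.

```python
from dataclasses import dataclass

@dataclass
class Block:
    """Same hypothetical stand-in for a KV cache block as in the earlier sketch."""
    block_id: int
    ref_cnt: int = 0

def num_common_prefix_blocks_fast(
    req_to_blocks: dict[str, list[Block]],
    running_request_ids: set[str],
    transferring_request_ids: set[str],
) -> int:
    """Roughly O((num_transferring_requests + 1) * num_common_blocks)."""
    if not running_request_ids:
        return 0
    num_running = len(running_request_ids)
    any_running = next(iter(running_request_ids))

    num_common = 0
    for i, block in enumerate(req_to_blocks[any_running]):
        # References held by finished-but-still-transferring requests that
        # share this block at the same prefix position.
        stale = sum(
            1
            for req_id in transferring_request_ids
            if len(req_to_blocks[req_id]) > i
            and req_to_blocks[req_id][i] is block
        )
        # With stale references discounted, ref_cnt reflects only the running
        # requests again, so the original equality test is safe to keep.
        if block.ref_cnt - stale == num_running:
            num_common += 1
        else:
            break
    return num_common
```

The loop stops at the first block that is not shared by all running requests, so only the common prefix (plus one block) is visited, and each visit touches only the transferring requests.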

@ayushsatyam146 (Contributor, Author)

Hi @heheda12345 I made the changes your way this time and have pushed them. Please take a look, thanks!

@ayushsatyam146 ayushsatyam146 force-pushed the kv-cache-fix branch 4 times, most recently from 64ed09d to 0d66b57 on September 2, 2025 03:53
@ayushsatyam146 (Contributor, Author)

Hi @heheda12345 just a gentle reminder to please take a look and approve if everything is right. Thanks!

@heheda12345 (Collaborator)

@ayushsatyam146 Hi, can you help to update this PR?

@ayushsatyam146 (Contributor, Author)

Hi @heheda12345 sorry, I got sick this week and couldn't work on this, but I am better now and will update it soon. Thanks for your patience.

@ayushsatyam146 (Contributor, Author)

@heheda12345, I tried to address all your concerns. Can you please take a look now? Thanks!

@mergify mergify bot commented Oct 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ayushsatyam146.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@ayushsatyam146 (Contributor, Author)

Hi @heheda12345 I resolved the merge conflicts and also included the changes you suggested. Please take a look, thanks!

@heheda12345 heheda12345 (Collaborator) left a comment

This implementation looks great!

@heheda12345 (Collaborator)

@codex review

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


* Replace ref_cnt-based common prefix detection with running request tracking
* Update get_num_common_prefix_blocks() to accept running_request_ids set
* Fix FullAttentionManager to count actual references from running requests
* Prevent incorrect cascade attention when async KV offloading delays cleanup

This resolves a bug where completed requests with pending async transfers
still contributed to ref_cnt, causing incorrect cascade attention decisions.

Signed-off-by: Ayush Satyam <ayushsatyam146@gmail.com>
@ayushsatyam146 (Contributor, Author)

Hi @heheda12345, I went through this approach, and apart from some occasional conservative handling of cascade attention, it looks good overall. I’ve implemented it as well — please take a look when you get a chance. Thanks!

@heheda12345 heheda12345 (Collaborator) left a comment

LGTM! I think this solution is clean.

@heheda12345 heheda12345 enabled auto-merge (squash) October 8, 2025 02:55
@github-actions github-actions bot added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) Oct 8, 2025
@heheda12345 heheda12345 merged commit cd98905 into vllm-project:main Oct 8, 2025
46 checks passed
mrasquinha-g pushed a commit to mrasquinha-g/vllm that referenced this pull request Oct 9, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
Dhruvilbhatt pushed a commit to Dhruvilbhatt/vllm that referenced this pull request Oct 14, 2025
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
alhridoy pushed a commit to alhridoy/vllm that referenced this pull request Oct 24, 2025
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Oct 24, 2025
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vLLM main; this PR is
aligned with vllm-project/vllm@17c540a

1. refactor deepseek to the latest code architecture as of
vllm-project/vllm@17c540a
 
2. a bunch of fixes due to vLLM changes
- Fix `AscendScheduler` `__post_init__`, caused by
vllm-project/vllm#25075
- Fix `AscendScheduler` init getting an unexpected arg `block_size`, caused
by vllm-project/vllm#26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
vllm-project/vllm#23485
- Fix `MLAAttention` import, caused by
vllm-project/vllm#25103
- Fix `SharedFusedMoE` import, caused by
vllm-project/vllm#26145
- Fix `LazyLoader` import, caused by
vllm-project/vllm#27022
- Fix `vllm.utils.swap_dict_values` import, caused by
vllm-project/vllm#26990
- Fix `Backend` enum import, caused by
vllm-project/vllm#25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by vllm-project/vllm#26355
- Fix fused_moe ops, caused by
vllm-project/vllm#24097
- Fix bert model because of `inputs_embeds`, caused by
vllm-project/vllm#25922
- Fix MRope due to the rename of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
vllm-project/vllm#24172
- Fix `splitting_ops` changes introduced by
vllm-project/vllm#25845
- Fix multi-modality changes introduced by
vllm-project/vllm#16229
- Fix lora bias dropping issue introduced by
vllm-project/vllm#25807
- Fix structured output breakage introduced by
vllm-project/vllm#26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025