Skip to content

Conversation

heheda12345
Copy link
Collaborator

@heheda12345 heheda12345 commented Jun 2, 2025

This PR fixes two bugs:

  1. If a model contains both sliding window attention and full attention, the sliding window attention layers are regarded as full attention layers when allocating kv cache, but in check_enough_kv_cache_memory, they are still regarded as sliding window attention layers. This PR fix it by changing the order of unify_hybrid_kv_cache_specs and check_enough_kv_cache_memory so that we first change sliding window layers to full attention layers, and then perform the check.

  2. The max concurrency estimation in _get_kv_cache_config_uniform_type didn't considered models with sliding window layers. This PR fixes it.

  • vllm serve bigcode/starcoder2-3b (a model that all layers use sliding window attention, result changed)

Main branch:

INFO 06-02 08:17:03 [kv_cache_utils.py:638] GPU KV cache size: 2,198,512 tokens
INFO 06-02 08:17:03 [kv_cache_utils.py:641] Maximum concurrency for 16,384 tokens per request: 134.19x

This PR:

INFO 06-02 08:11:22 [kv_cache_utils.py:675] GPU KV cache size: 2,198,512 tokens
INFO 06-02 08:11:22 [kv_cache_utils.py:679] Maximum concurrency for 16,384 tokens per request: 178.68x
  • vllm serve meta-llama/Llama-3.1-8B-Instruct (a model that all layers use full attention, result not changed)

Main branch:

INFO 06-02 08:18:09 [kv_cache_utils.py:638] GPU KV cache size: 415,248 tokens
INFO 06-02 08:18:09 [kv_cache_utils.py:641] Maximum concurrency for 131,072 tokens per request: 3.17x

This PR:

INFO 06-02 08:13:24 [kv_cache_utils.py:675] GPU KV cache size: 415,248 tokens
INFO 06-02 08:13:24 [kv_cache_utils.py:679] Maximum concurrency for 131,072 tokens per request: 3.17x

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Copy link

github-actions bot commented Jun 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 2, 2025
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 3, 2025
@WoosukKwon WoosukKwon enabled auto-merge (squash) June 3, 2025 16:14
@WoosukKwon WoosukKwon merged commit a8da78e into vllm-project:main Jun 4, 2025
71 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants