
Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support #11844


Merged
merged 1 commit into vllm-project:main on May 13, 2025

Conversation

sighingnow
Collaborator

@sighingnow sighingnow commented Jan 8, 2025

This PR implements dual-chunk flash attention, a training-free method to extend model context length (see also #6139), with sparse attention (https://github.com/microsoft/MInference) support.

This PR requires the sparse attention kernel from vllm-flash-attention. Qwen models with 1M context length support will be open-sourced in the next one or two weeks, and unit tests will be added later.

FIX #12452
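
For readers unfamiliar with the method, a rough conceptual sketch of the dual-chunk decomposition follows (this is not the PR's kernel: no sparse attention, no paged KV cache, and the position-remapped query variants q_intra / q_succ / q_inter are assumed to be precomputed by the rotary-embedding step):

import torch

def dual_chunk_scores(q_intra, q_succ, q_inter, keys, i, chunk_len):
    """Attention scores for query token i over keys[:i + 1], split the
    dual-chunk way: the query's own chunk is scored with q_intra, the
    immediately preceding chunk with q_succ, and all earlier chunks with
    q_inter. A single softmax over the concatenated scores then yields
    causal attention with bounded relative positions."""
    c = i // chunk_len                        # index of the query's chunk
    own_start = c * chunk_len
    prev_start = max(c - 1, 0) * chunk_len
    scores_inter = q_inter @ keys[:prev_start].T          # distant chunks
    scores_succ = q_succ @ keys[prev_start:own_start].T   # previous chunk
    scores_intra = q_intra @ keys[own_start:i + 1].T      # own chunk (causal)
    return torch.cat([scores_inter, scores_succ, scores_intra])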


github-actions bot commented Jan 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jan 8, 2025
@sighingnow sighingnow force-pushed the dev/dual-chunk-attn branch 2 times, most recently from 82b5a4c to 4c4a33e Compare January 9, 2025 06:17
@jacob-crux

I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph.
Do you plan to fix this in the future?


mergify bot commented Jan 13, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sighingnow
Collaborator Author

I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?

All conflicts fixed, could you please take another look? thanks!

st] = decode_metadata.block_tables[i, st:ed]
decode_metadata.block_tables_intra = block_tables_intra

seq_lens_succ = (chunk_num_curr -


When I try the needle-in-a-haystack test with Qwen-7B and Llama-8B (code modified to support Llama), there is a bug that produces a negative number once the context exceeds roughly 13k–15k tokens.
I modified the code as below and confirmed that it works.

seq_lens_succ = ((chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len)
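
For illustration, here is the proposed expression evaluated on hypothetical values with plain torch (not the backend's actual metadata code); it shows the result never goes negative and that sequences still in their first chunk (chunk_num_curr == 0) get a length of 0:

import torch

chunk_len = 8192  # hypothetical chunk length
# Hypothetical "current chunk index" per sequence; 0 means the sequence
# has not grown past its first chunk yet.
chunk_num_curr = torch.tensor([0, 1, 2, 5])

seq_lens_succ = (chunk_num_curr - (chunk_num_curr - 1).clip(min=0)) * chunk_len
print(seq_lens_succ)  # tensor([   0, 8192, 8192, 8192])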


mergify bot commented Jan 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 15, 2025
@jacob-crux

I see that you have enforce_eager=True set, so it looks like there are still compatibility issues with cudagraph. Do you plan to fix this in the future?

All conflicts fixed, could you please take another look? thanks!

I tested it because I thought it was fixed, but I still hit the same problem, shown below.
Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Capturing CUDA graph shapes:   0%|                                                                                                                                                                                                               | 0/35 [00:00<?, ?it/s]
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 435, in <module>
[rank0]:     ht = LLMNeedleHaystackTester(
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data/lme-storage_810/jacob/needle/NeedleInAHaystack-lme/run_needle_in_haystack.py", line 94, in __init__
[rank0]:     self.model_to_test = LLM(model=model_name)
[rank0]:                          ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/utils.py", line 1044, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/entrypoints/llm.py", line 228, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 517, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 276, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/engine/llm_engine.py", line 429, in _initialize_kv_caches
[rank0]:     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/executor/gpu_executor.py", line 83, in initialize_cache
[rank0]:     self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 274, in initialize_cache
[rank0]:     self._warm_up_model()
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/worker.py", line 292, in _warm_up_model
[rank0]:     self.model_runner.capture_model(self.gpu_cache)
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1533, in capture_model
[rank0]:     graph_runner.capture(**capture_inputs)
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/worker/model_runner.py", line 1885, in capture
[rank0]:     self.model(
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 496, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 359, in forward
[rank0]:     hidden_states, residual = layer(
[rank0]:                               ^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 267, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/model_executor/models/qwen2.py", line 189, in forward
[rank0]:     attn_output = self.attn(q,
[rank0]:                   ^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 185, in forward
[rank0]:     return torch.ops.vllm.unified_attention(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/layer.py", line 280, in unified_attention
[rank0]:     return self.impl.forward(query, key, value, kv_cache, attn_metadata,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/bc-user/vllm_dual_chunk_250114/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 373, in forward
[rank0]:     assert decode_meta.scaling_factor is not None
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError

@mergify mergify bot removed the needs-rebase label Jan 16, 2025
@sighingnow
Collaborator Author

I tested it because I thought it was fixed, but I still hit the same problem, shown below.
Are you saying that CUDA graph capture is possible (enforce_eager=False)?

Dual chunk attention doesn't support CUDA graphs; I have added an assertion in arg_utils.py.

When I try the needle-in-a-haystack test with Qwen-7B and Llama-8B (code modified to support Llama), there is a bug that produces a negative number once the context exceeds roughly 13k–15k tokens.

It was indeed a bug introduced while preparing this PR; fixed now. Thanks!

@sighingnow
Collaborator Author

sighingnow commented Jan 19, 2025

Rebased against main.

Hi @youkaichao @simon-mo @WoosukKwon, do you folks think there is anything else that needs to be improved in this pull request?

Thanks!

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


Spotted a few bits of commented-out code that look like debug cruft or are otherwise mysterious. Could you clean those up, along with any other similar spots?


mergify bot commented Jan 20, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @sighingnow.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 20, 2025
qc_freqs = torch.einsum("i,j -> ij", qc_t, inv_freq)
k_freqs = torch.einsum("i,j -> ij", k_t, inv_freq)
qc_no_clamp_freqs = torch.einsum("i,j -> ij", qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.einsum("i,j -> ij", q_inter_t, inv_freq)
Collaborator


nit: I think these einsums are still slower on CUDA than (a * b).sum(-1); not on the hot path though, so not critical

pytorch/pytorch#101249

ran einsum_bench.py from that issue on an H100 and got:

python einsum_bench.py 
[-------------------------------------  -------------------------------------]
                                  |  mul/sum  |  torch.einsum  |  numpy.einsum
1 threads: -------------------------------------------------------------------
      Nc,Nc->N cpu (1048576, 2)   |    5000   |      3100      |      4000    
      Nc,Nc->N cuda (1048576, 2)  |      20   |       747      |      3300    

Times are in microseconds (us).
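
For reference, the "i,j -> ij" einsums quoted above are plain outer products, so they have a direct broadcasting equivalent; a small sketch with hypothetical inputs (not the PR's actual rotary code):

import torch

# Hypothetical position indices and inverse frequencies.
qc_t = torch.arange(16, dtype=torch.float32)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))

# The einsum form used in the diff above...
qc_freqs_einsum = torch.einsum("i,j -> ij", qc_t, inv_freq)

# ...is just an outer product, so broadcasting (or torch.outer) produces the
# same result while avoiding einsum's extra parsing/dispatch work.
qc_freqs_bcast = qc_t[:, None] * inv_freq[None, :]

assert torch.allclose(qc_freqs_einsum, qc_freqs_bcast)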

Comment on lines 102 to 115
logits_soft_cap, attn_type, **{
    "dual_chunk_attention_config": dual_chunk_attention_config,
    "prefix": prefix,
} if dual_chunk_attention_config is not None else {})
Collaborator


I feel like this is messy; I think we should maybe do something like:

def __init__(..., **extra_attn_kwargs):
    self.impl = impl_cls(..., **extra_attn_kwargs)

The challenge here is that prefix would not be captured by extra_attn_kwargs, but it is (currently) only used by DualChunkFlashAttentionImpl. I do think it would be less messy, though, to do this and make prefix a standard arg for attention impls, given that it is pretty generic. Thoughts @WoosukKwon?
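
To make the suggestion concrete, here is a self-contained sketch of the proposed shape (class and argument names are illustrative stand-ins, not vLLM's actual signatures):

class DualChunkFlashAttentionImpl:
    """Stand-in for the backend impl; the real one lives in vLLM."""

    def __init__(self, num_heads, head_size, scale, prefix="",
                 dual_chunk_attention_config=None):
        self.prefix = prefix
        self.dual_chunk_attention_config = dual_chunk_attention_config


class Attention:
    """Illustrative attention layer: prefix is a standard impl argument and
    backend-specific options flow through **extra_attn_kwargs untouched,
    replacing the conditional dict-splat shown in the diff above."""

    def __init__(self, num_heads, head_size, scale, impl_cls, prefix="",
                 **extra_attn_kwargs):
        self.impl = impl_cls(num_heads, head_size, scale, prefix=prefix,
                             **extra_attn_kwargs)


# Usage: the dual-chunk config simply rides along as an extra kwarg.
attn = Attention(8, 128, 1.0, DualChunkFlashAttentionImpl,
                 prefix="model.layers.0.self_attn",
                 dual_chunk_attention_config={"chunk_size": 8192})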

Comment on lines 148 to 158
if self.dual_chunk_attention_config:
    assert query_succ_and_inter is not None
    dca_kwargs = {
        "query_succ": query_succ_and_inter[0],
        "query_inter": query_succ_and_inter[1],
        "query_succ_critical": query_succ_and_inter[2],
        "query_inter_critical": query_succ_and_inter[3],
    } if query_succ_and_inter else {}
else:
    dca_kwargs = {}

Collaborator


I think we should try hard to see if there is a cleaner way of passing these; maybe they can be bundled into a single q tensor that gets reinterpreted into its components via a combination of slicing and .view calls in the attn impl? (A sketch of that idea follows below.)
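
One possible shape for that, sketched with plain tensors (hypothetical sizes; in the real code the bundling would happen around the rotary-embedding step):

import torch

num_tokens, q_width = 4, 8 * 128  # hypothetical token count and query width

# The query variants produced for dual chunk attention.
variants = [torch.randn(num_tokens, q_width) for _ in range(5)]

# Caller side: bundle them into one contiguous tensor.
q_bundled = torch.cat(variants, dim=-1)            # (num_tokens, 5 * q_width)

# Attention-impl side: recover the components with a view plus unbind;
# these are views into q_bundled, so no extra copies are made.
q, q_succ, q_inter, q_succ_crit, q_inter_crit = (
    q_bundled.view(num_tokens, 5, q_width).unbind(dim=1))

assert torch.equal(q, variants[0]) and torch.equal(q_inter, variants[2])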

Collaborator Author


I will give it a try and see if it can be simplified.

@simon-mo simon-mo merged commit 60f7624 into vllm-project:main May 13, 2025
86 of 91 checks passed
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
@sighingnow sighingnow deleted the dev/dual-chunk-attn branch May 16, 2025 08:34
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…h sparse attention support (vllm-project#11844)

Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
@ExtReMLapin

So if I understand correctly, Qwen2.5-1M now actually uses the correct attention mechanism, so VRAM usage should be lower and prompt processing should be faster, right?

@exceedzhang

I tested Qwen/Qwen2.5-7B-Instruct-1M using the DualChunkFlashAttention backend.
(screenshot)
It starts up fine, but it does not work well. @sighingnow

ubuntu-vllm-openai-1 | INFO 05-31 19:13:07 [logger.py:42] Received request cmpl-77d91882816c4f748e2023c93449f62d-0: prompt: 'Once upon a time', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.05, temperature=0.0, top_p=1.0, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=1000, min_tokens=0, logprobs=1, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: [12522, 5193, 264, 882], prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
ubuntu-vllm-openai-1 | INFO 05-31 19:13:07 [engine.py:316] Added request cmpl-77d91882816c4f748e2023c93449f62d-0.
ubuntu-vllm-openai-1 | INFO: 172.18.0.1:46884 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] AssertionError('seqused_k must be provided if block_table is provided')
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] Traceback (most recent call last):
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 162, in start
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] self.run_engine_loop()
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 225, in run_engine_loop
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] request_outputs = self.engine_step()
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 251, in engine_step
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] raise e
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 234, in engine_step
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self.engine.step()
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 1393, in step
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] outputs = self.model_executor.execute_model(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 299, in execute_model
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] driver_outputs = self._driver_execute_model(execute_model_req)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/mp_distributed_executor.py", line 144, in _driver_execute_model
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self.driver_worker.execute_model(execute_model_req)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker_base.py", line 420, in execute_model
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] output = self.model_runner.execute_model(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return func(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1843, in execute_model
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] hidden_or_intermediate_states = model_executable(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self._call_impl(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return forward_call(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 481, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] hidden_states = self.model(input_ids, positions, intermediate_tensors,
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 172, in call
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self.forward(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 358, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] hidden_states, residual = layer(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self._call_impl(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return forward_call(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 257, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] hidden_states = self.self_attn(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self._call_impl(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return forward_call(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/qwen2.py", line 187, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] attn_output = self.attn(q, k, v)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self._call_impl(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return forward_call(*args, **kwargs)
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 237, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return torch.ops.vllm.unified_attention(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1158, in call
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] return self._op(*args, **(kwargs or {}))
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/layer.py", line 386, in unified_attention
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] output = self.impl.forward(self, query, key, value, kv_cache,
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 493, in forward
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] self._dual_chunk_flash_attn_prefill(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 673, in _dual_chunk_flash_attn_prefill
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] current_out = self._dual_chunk_flash_attn_prefill_func(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 1055, in _dual_chunk_flash_attn_prefill_func
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] flash_result = self._do_flash_attn(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 1207, in _do_flash_attn
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] output, softmax_lse = flash_attn_varlen_func(
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] File "/usr/local/lib/python3.12/dist-packages/vllm/vllm_flash_attn/flash_attn_interface.py", line 204, in flash_attn_varlen_func
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] assert block_table is None or seqused_k is not None,
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ubuntu-vllm-openai-1 | ERROR 05-31 19:13:08 [engine.py:164] AssertionError: seqused_k must be provided if block_table is provided

@ExtReMLapin

Exact same issue as above

@ExtReMLapin

ExtReMLapin commented Jun 5, 2025

PR #19084 fixes this issue.

When working with 70k-token contexts, the loaded model plus the context uses something like 30 GB of VRAM, but during inference it climbs to 35–37 GB and then drops back to 30 GB.

I'm guessing that's expected, but is there some way to preallocate this memory? Because if you let vLLM allocate 80% of the VRAM and it then tries to "eat" more VRAM, it will obviously OOM.

Edit:

  • FP8 model quantization is not working
  • --pipeline_parallel_size is not working
  • --tensor_parallel_size is not working

@mklasby

mklasby commented Jun 5, 2025

@ExtReMLapin

When working with contexts of 70k, with the model loaded + the context it uses something like 30Gb of vram, but during inference it goes up to 35-37gb of vram then back down to 30Gb.

The qk estimate softmax has high memory overhead: https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/dual_chunk_flash_attn.py#L834

During start-up profiling, DCA specifically routes to flash-attention instead of the DCA sparse prefill function:
https://github.com/vllm-project/vllm/blob/main/vllm/attention/backends/dual_chunk_flash_attn.py#L474

In principle, there's no reason to use flash-attention during profiling from what I can see, so having that branch call the sparse-attention path instead should at least surface the OOM during profiling.

@ExtReMLapin

Not a "blog", but it may help people working with this: so far we got much better results with Qwen2.5-7B-1M than with NVIDIA's Nemotron 4M.

However, besides the quantization and GPU-splitting issues stated before, we also did not manage to do batching/parallel processing.

@ExtReMLapin

Quantization support has been added in #19420.

Could not test KV cache quantization because this attention mechanism is based on flash-attention.

minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
…h sparse attention support (vllm-project#11844)

Signed-off-by: minpeter <kali2005611@gmail.com>
@ExtReMLapin

ExtReMLapin commented Jul 4, 2025

A bug appeared between commits bbfa0c6 and b9a1791 that makes vllm serve crash when using DCA.

Tracking it down...

Edit: ref #20484

@sighingnow
Collaborator Author

Thanks for reporting, @ExtReMLapin @exceedzhang. Will investigate this week.

@ExtReMLapin

It’s already fixed and a PR has been merged.

@ExtReMLapin

@sighingnow @exceedzhang thanks for your contributions; these are mostly the PRs that need review:

Priority (crash fix): #19084

FP8 quantization support: #19420

@exceedzhang

FP8 quantization support #19420

Thank you for your development work; I've tested it, and the feature functions correctly. However, I've noticed a performance drop after enabling FP8 quantization.

Here are the performance test results using four RTX 4090 24 GB GPUs.

(screenshot of results)

FP8 quantization:

(screenshot of results)

@ExtReMLapin

I agree with you; we would expect better performance with FP8 because of the lower memory bottleneck.

I also have another update waiting under the hood on this branch, which should improve performance (packed torch operations): https://github.com/ExtReMLapin/vllm/tree/faster_dca, but I haven't run enough tests on it yet.

Considering the slowdowns... isn't that the fault of the flash-attention implementation, given how few changes I made?

@exceedzhang

I agree with you; we would expect better performance with FP8 because of the lower memory bottleneck.

I also have another update waiting under the hood on this branch, which should improve performance (packed torch operations): https://github.com/ExtReMLapin/vllm/tree/faster_dca, but I haven't run enough tests on it yet.

Considering the slowdowns... isn't that the fault of the flash-attention implementation, given how few changes I made?

@ExtReMLapin
Thanks for optimizing the code, but I've tested it and the performance difference compared to the previous version isn't significant! I ran stress tests on Qwen2.5-7B-1M using a server with four RTX 4090 24 GB GPUs.

(screenshot of results)

@ExtReMLapin

Got it; I won't merge this performance branch into the FP8 branch then. It's not worth the risk of breaking something!

Again, at the office we really appreciate the effort spent on releasing these models.

We ran a lot of tests, including on other models claiming to have long context:

  • Llama-3.1-Nemotron-8B-UltraLong-1M-Instruct: does not follow instructions correctly
  • gradientai/Llama-3-8B-Instruct-262k: follows instructions, but struggles to speak anything other than English
  • 01-ai/Yi-9B-200K: chat template is broken
  • Phi-3 128k: not enough VRAM for a 128k context
  • Menlo/Jan-nano-128k: really meh results, does not follow instructions correctly
  • aws-prototyping/MegaBeam-Mistral-7B-512k: same issues as above

but this is the ONLY model that actually follows instructions on very long contexts and can be run easily (without insane resources).

Looking forward to seeing more models like this in the future!

@ExtReMLapin

Well, I'm not sure exactly what happened, but reading my PR code again and again, it should only affect KV cache quantization, not model quantization. Checking again without my PR, quantization seems to work without my changes, which makes this comment sound like I'm insane: #19084 (comment)

@ExtReMLapin

Ran more tests:

  • FP8 works, no need for a PR
  • tensor parallel works
  • ⚠️ DCA/sparse attention is broken on Blackwell; fixed by this PR: Sparse attention: Generalize arch checks for A100 and above flash-attention#73
  • ⚠️ Memory grows instead of being preallocated, which causes an OOM error with the default allocation percentage
    • tested on RTX 5090s: CUDA_VISIBLE_DEVICES=0,1,2 VLLM_ATTENTION_BACKEND=DUAL_CHUNK_FLASH_ATTN vllm serve Qwen/Qwen2.5-7B-Instruct-1M --max-model-len 240000 --max-num-seqs 2 --port 2483 --enforce-eager --enable-server-load-tracking --disable-log-requests --max-num-seqs 50 --quantization fp8 --tensor-parallel-size 2
    • no issue with the VRAM limit at 80% (--gpu-memory-utilization 0.8) and a 240000-token context filled to 90% in one query

Labels
ci/build · documentation (Improvements or additions to documentation) · ready (ONLY add when PR is ready to merge/full CI is needed)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature]: Support Qwen/Qwen2.5-14B-Instruct-1M