fix: enable FA3 for SM80+ GPUs and fix CUDA version comparison #4591
Open
windreamer wants to merge 6 commits into
…ng on CUDA
- Update use_fa3 capability check from == 9 (SM90 only) to >= 8 (SM80+) in attention/__init__.py and configurations/utils.py
- Add FA3 requirement check in graph_runner.py: speculative decoding on CUDA now raises a clear error if FA3 is unavailable, instead of crashing deep in the Triton paged attention kernel
- Update docstrings/error messages to reflect SM80+ (Ampere) support
…tion
FA3 mha_fwd derives seqlen_k from page_table.shape[1] * page_size for paged KV without cu_seqlens_k. get_scheduler_metadata must receive the same value to produce a consistent scheduler layout.
Previously max_seqlen_k was incorrectly set to step_context.max_kv_seqlen (runtime KV length) in op_backend.py, and decode_query_len or attn_metadata.max_kv_seqlen in cudagraph.py. These values differ from what FA3 computes internally, causing scheduler_metadata to be misaligned with the actual kernel behavior.
- op_backend.py: use block_offsets.size(1) * block_size
- cudagraph.py: use graph_meta.num_blocks * graph_meta.block_size
Both now match FA3 internal: page_table.size(1) * page_size.
Co-authored-by: openhands <openhands@all-hands.dev>
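The gap between the two sizes can be illustrated with plain arithmetic. This is only a sketch: paged_seqlen_k and the sample numbers are hypothetical, and the block table is assumed to be just wide enough for the longest sequence (a real table may be padded wider).

```python
import math


def paged_seqlen_k(num_block_columns: int, block_size: int) -> int:
    """Capacity implied by the block table: columns * block size."""
    return num_block_columns * block_size


block_size = 64
kv_seqlens = [70, 130]  # runtime KV lengths of two sequences (made-up values)

# Assumption: the block table has just enough columns for the longest sequence.
num_block_columns = max(math.ceil(n / block_size) for n in kv_seqlens)

print(max(kv_seqlens))                                # 130, the runtime KV length
print(paged_seqlen_k(num_block_columns, block_size))  # 192, the paged capacity
```

The two values agree only when the longest sequence happens to fill its last block exactly, which is why passing the runtime length to the scheduler metadata call can disagree with what the kernel derives from the page table.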
- Add RuntimeError in op_backend.py update_step_context when speculative decoding is used without FA3, matching the existing check in graph_runner.py. Previously only the CUDA Graph path validated FA3 availability; the eager path would fail with an unhelpful ImportError.
- Remove unused flash_attn_v3_available() from configurations/utils.py. The canonical FA3 availability check lives in attention/__init__.py (use_fa3 module variable) which is already imported by graph_runner.py.
String comparison '12.10' >= '12.3' evaluates to False because it compares character by character. Switch to tuple-of-ints comparison to correctly handle minor versions with different digit counts.
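A minimal sketch of the tuple-of-ints comparison, combined with the SM80+ capability gate. The helper names parse_cuda_version and fa3_gate are hypothetical and are not taken from the PR; the actual check in attention/__init__.py may be structured differently.

```python
from typing import Optional, Tuple


def parse_cuda_version(version: str) -> Tuple[int, ...]:
    """Turn '12.10' into (12, 10) so the components compare numerically."""
    return tuple(int(part) for part in version.split(".")[:2])


def fa3_gate(sm_major: int, cuda_version: Optional[str]) -> bool:
    """Hypothetical gate: SM80+ and CUDA >= 12.3."""
    if cuda_version is None:  # PyTorch builds without CUDA report None
        return False
    return sm_major >= 8 and parse_cuda_version(cuda_version) >= (12, 3)


# String comparison goes wrong once the minor version has two digits.
print("12.10" >= "12.3")                       # False
print(parse_cuda_version("12.10") >= (12, 3))  # True
print(fa3_gate(8, "12.10"))                    # True
```

Only the major and minor components are parsed, so a string with a patch component is compared on its first two fields.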
Pull request overview
This PR updates the CUDA backend to correctly enable and use FlashAttention-3 (FA3) on SM80+ GPUs (Ampere and newer) with CUDA ≥ 12.3, and fixes FA3 scheduler-metadata sizing for paged-KV speculative decoding to avoid buffer mismatches/correctness issues.
Changes:
- Relax FA3 GPU capability gating to SM80+ and fix CUDA version comparison to use numeric parsing (not string compare).
- Fix FA3 max_seqlen_k sizing to match FA3’s internal paged-KV computation (page_table.size(1) * page_size) across eager + CUDA Graph paths.
- Add explicit FA3 availability checks/errors for speculative decoding paths; remove unused FA3 availability helper.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| lmdeploy/pytorch/models/utils/cudagraph.py | Adjust FA3 scheduler metadata buffer sizing in CUDA Graph capture/replay to use num_blocks * block_size. |
| lmdeploy/pytorch/configurations/utils.py | Remove unused flash_attn_v3_available() helper. |
| lmdeploy/pytorch/backends/cuda/op_backend.py | Compute FA3 max_seqlen_k from block_offsets shape; add eager speculative-decoding FA3 availability error. |
| lmdeploy/pytorch/backends/cuda/graph_runner.py | Add early FA3 requirement check for CUDA Graph runner when using speculative decoding. |
| lmdeploy/pytorch/backends/cuda/attention/fa3.py | Update docstring to reflect SM80+ support. |
| lmdeploy/pytorch/backends/cuda/attention/__init__.py | Enable FA3 on SM80+ and fix CUDA version comparison logic. |
Comments suppressed due to low confidence (1)
lmdeploy/pytorch/backends/cuda/op_backend.py:241
- The raised error claims the "Current GPU does not meet these requirements", but use_fa3 can also be false due to a missing/failed flash-attn install or a missing torch.ops.flash_attn_3. Consider rewording to avoid attributing the failure solely to GPU capability, and ideally include the detected SM/CUDA version (or that flash-attn is not installed) to make the error actionable.
from .attention import use_fa3
if not use_fa3:
    raise RuntimeError(
        'Speculative decoding on CUDA requires FlashAttention-3 (FA3), '
        'which is available on SM80+ GPUs (Ampere architecture and above) '
        'with CUDA >= 12.3. Current GPU does not meet these requirements. '
        'Please use a SM80+ GPU with CUDA >= 12.3 and install flash-attn, '
        'or disable speculative decoding.')
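One way to act on this comment is to build the message from whatever was actually detected. The sketch below is illustrative only: explain_fa3_unavailable and its parameters are hypothetical, and the caller is assumed to have already queried the SM version, the CUDA version, and whether the FA3 kernels import.

```python
from typing import List, Optional, Tuple


def explain_fa3_unavailable(sm: Optional[Tuple[int, int]],
                            cuda_version: Optional[str],
                            flash_attn_importable: bool) -> str:
    """Build an error message naming every detected cause (illustrative helper)."""
    reasons: List[str] = []
    if sm is None:
        reasons.append("no CUDA device detected")
    elif sm[0] < 8:
        reasons.append(f"detected SM{sm[0]}{sm[1]}, need SM80+")
    if cuda_version is None:
        reasons.append("PyTorch reports no CUDA version")
    elif tuple(int(p) for p in cuda_version.split(".")[:2]) < (12, 3):
        reasons.append(f"detected CUDA {cuda_version}, need >= 12.3")
    if not flash_attn_importable:
        reasons.append("flash-attn FA3 kernels could not be imported")
    if not reasons:
        reasons.append("cause not determined")
    return ("Speculative decoding on CUDA requires FlashAttention-3 (FA3), "
            "which is unavailable: " + "; ".join(reasons) + ". "
            "Install flash-attn on an SM80+ GPU with CUDA >= 12.3, "
            "or disable speculative decoding.")


# A capable GPU and toolkit, but flash-attn is missing.
print(explain_fa3_unavailable((8, 6), "12.4", False))
```

With a capable GPU and a missing flash-attn install, the message no longer blames the hardware.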
- Use _normalize_sliding_window() instead of raw tuple multiply in op_backend.py update_meta_flashattn, consistent with cudagraph.py and attention/__init__.py handling of None/int/tuple sliding_window.
- Improve FA3 RuntimeError messages in both op_backend.py and graph_runner.py: include detected SM version and CUDA version, and avoid misleadingly attributing failure solely to GPU capability (flash-attn may simply not be installed).
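To show why the raw tuple multiply is fragile, here is a sketch of a normalizer for the three input shapes. The function normalize_sliding_window_sketch is hypothetical and is not the PR's _normalize_sliding_window(); mapping None to (-1, -1) is an assumption borrowed from flash-attn's convention that a window size of (-1, -1) means no window.

```python
from typing import Optional, Tuple, Union

SlidingWindow = Union[None, int, Tuple[int, int]]


def normalize_sliding_window_sketch(sliding_window: SlidingWindow) -> Tuple[int, int]:
    """Map None / int / (left, right) to a (left, right) pair."""
    if sliding_window is None:
        return (-1, -1)  # assumed sentinel for "no sliding window"
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)
    left, right = sliding_window  # already a pair
    return (int(left), int(right))


# What the raw multiply produces for the two problematic inputs:
print((None,) * 2)        # (None, None)
print(((128, 128),) * 2)  # ((128, 128), (128, 128))

# The normalizer returns a flat pair of ints in every case:
print(normalize_sliding_window_sketch(None))        # (-1, -1)
print(normalize_sliding_window_sketch(128))         # (128, 128)
print(normalize_sliding_window_sketch((128, 128)))  # (128, 128)
```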
grimoire reviewed May 18, 2026
DeepSeek MTP models use FlashMLA instead of FA3 for speculative decoding. The FA3 requirement check and use_fa3_decoding flag should not trigger when use_flash_mla is enabled, matching the logic in CudaOpsBackend.update_step_context where FlashMLA takes priority.
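The priority described here can be sketched as a small predicate. Both function names and all three flags are hypothetical stand-ins for the state the real code reads; the sketch shows only the ordering of the checks, not the PR's implementation.

```python
def fa3_required(is_speculative: bool, use_flash_mla: bool) -> bool:
    """FA3 is only needed for speculative decoding when FlashMLA is not in use."""
    return is_speculative and not use_flash_mla


def check_fa3(is_speculative: bool, use_flash_mla: bool, use_fa3: bool) -> None:
    """Raise only when FA3 is actually required and unavailable."""
    if fa3_required(is_speculative, use_flash_mla) and not use_fa3:
        raise RuntimeError(
            "Speculative decoding on CUDA requires FlashAttention-3 (FA3).")


# FlashMLA takes priority: no error even though FA3 is unavailable.
check_fa3(is_speculative=True, use_flash_mla=True, use_fa3=False)
```

Without the FlashMLA guard, the same call would raise even though the FA3 path is never taken for such models.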
Motivation
FlashAttention-3 now supports SM80+ (Ampere and above) GPUs with CUDA >= 12.3, not just SM90 (Hopper). The existing code had several issues preventing proper FA3 usage on Ampere GPUs and could produce incorrect results or unhelpful errors.
Modification
1. Enable FA3 for SM80+ GPUs (attention/__init__.py)
- Changed the use_fa3 capability check from == 9 (Hopper-only) to >= 8 (Ampere and above), matching the current flash-attention support matrix.
- torch.version.cuda >= "12.3" used string comparison, which incorrectly evaluates "12.10" < "12.3". Replaced with numeric tuple comparison, e.g. (12, 10) >= (12, 3).
2. Align max_seqlen_k with FA3 internal computation (op_backend.py, cudagraph.py)
- FA3 mha_fwd internally computes seqlen_k = page_table.size(1) * page_size when cu_seqlens_k is None (paged KV without varlen_k). The scheduler metadata buffer must be sized to match, otherwise buffer size mismatches cause incorrect results.
- Updated the call sites (update_meta_flashattn, update_meta_flashattn_decoding, and both CudaGraphMixin paths) to use num_blocks * block_size instead of max_kv_seqlen.
3. Add FA3 availability check in eager path for speculative decoding (op_backend.py, graph_runner.py)
- graph_runner.py already validates FA3 availability at CUDA Graph init time, but the eager mode path (through update_step_context) had no such check. Without FA3, speculative decoding would fail with an unhelpful ImportError deep inside update_meta_flashattn.
- Added a RuntimeError with a clear message in op_backend.py:update_step_context when model_paradigm == "ar_spec" and FA3 is unavailable, matching the graph runner check.
4. Normalize sliding_window in update_meta_flashattn (op_backend.py)
- Replaced (sliding_window,) * 2 with _normalize_sliding_window(), consistent with how cudagraph.py and attention/__init__.py handle None/int/tuple sliding_window values. This prevents producing invalid (None, None) or nested tuples.
5. Remove dead code (configurations/utils.py)
- Removed flash_attn_v3_available(), which was only defined but never called anywhere. The canonical FA3 availability check is the use_fa3 module variable in attention/__init__.py.
BC-breaking (Optional)
No BC-breaking changes. The SM capability threshold is relaxed (SM90 → SM80+), which only enables FA3 on more hardware. The max_seqlen_k change fixes a correctness issue. The eager-mode check raises an explicit error where previously an opaque ImportError would occur.
Checklist