fix: enable FA3 for SM80+ GPUs and fix CUDA version comparison #4591

Open
windreamer wants to merge 6 commits into InternLM:main from windreamer:ampere-fa3

Collaborator

@windreamer windreamer commented May 18, 2026

Motivation

FlashAttention-3 now supports SM80+ (Ampere and above) GPUs with CUDA >= 12.3, not just SM90 (Hopper). The existing code had several issues preventing proper FA3 usage on Ampere GPUs and could produce incorrect results or unhelpful errors.

Modification

1. Enable FA3 for SM80+ GPUs (attention/__init__.py)

  • Changed SM capability check from == 9 (Hopper-only) to >= 8 (Ampere and above), matching the current flash-attention support matrix.
  • Fixed CUDA version comparison: torch.version.cuda >= "12.3" compared strings lexicographically, so "12.10" sorted before "12.3" and CUDA 12.10 was wrongly rejected. Replaced with a numeric tuple comparison, e.g. (12, 10) >= (12, 3).
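
The difference between the two comparisons can be reproduced in isolation. The following is a minimal sketch, not the code from this PR: the helper names parse_cuda_version and fa3_supported are hypothetical, and the version string stands in for torch.version.cuda (which is None on CPU-only builds).

```python
def parse_cuda_version(version):
    """Return (major, minor) as ints from a string like '12.10', or None."""
    if not version:
        return None
    parts = version.split('.')
    try:
        return int(parts[0]), int(parts[1])
    except (IndexError, ValueError):
        return None


def fa3_supported(sm_major, cuda_version, min_cuda=(12, 3)):
    """Hypothetical gate: SM80+ and CUDA >= 12.3, compared numerically."""
    parsed = parse_cuda_version(cuda_version)
    return sm_major >= 8 and parsed is not None and parsed >= min_cuda


# String comparison goes character by character, so '12.10' sorts before '12.3'.
print('12.10' >= '12.3')
# Tuple-of-ints comparison orders the minor versions numerically.
print(parse_cuda_version('12.10') >= (12, 3))
print(fa3_supported(8, '12.10'))
```

Returning None for an unparseable or missing version keeps the gate closed instead of raising at import time, which is one reasonable choice for a module-level flag.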

2. Align max_seqlen_k with FA3 internal computation (op_backend.py, cudagraph.py)

  • FA3's mha_fwd internally computes seqlen_k = page_table.size(1) * page_size when cu_seqlens_k is None (paged KV without varlen_k). The scheduler metadata buffer must be sized to match, otherwise buffer size mismatches cause incorrect results.
  • Updated all call sites (update_meta_flashattn, update_meta_flashattn_decoding, and both CudaGraphMixin paths) to use num_blocks * block_size instead of max_kv_seqlen.
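
To illustrate why the two values differ, here is a small sketch using a nested list as a stand-in for the page table tensor. The function name and the sample numbers are illustrative assumptions; only the formula page_table.size(1) * page_size comes from the description above.

```python
def fa3_max_seqlen_k(page_table, page_size):
    """Mirror the formula described above: number of page-table columns
    times the page size. len(page_table[0]) stands in for
    page_table.size(1) on a real tensor."""
    num_blocks = len(page_table[0])
    return num_blocks * page_size


# Two sequences, each with 8 page-table slots, 64 tokens per page.
page_table = [[0, 1, 2, 3, 4, 5, 6, 7],
              [8, 9, 10, 11, 12, 13, 14, 15]]
page_size = 64

# Runtime KV lengths are usually shorter than the allocated capacity.
kv_seqlens = [300, 129]
max_kv_seqlen = max(kv_seqlens)

capacity = fa3_max_seqlen_k(page_table, page_size)
print(max_kv_seqlen, capacity)
```

In this sketch the runtime maximum (300) and the page-table capacity (512) disagree, which is the kind of mismatch the change removes by always passing the capacity.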

3. Add FA3 availability check in eager path for speculative decoding (op_backend.py, graph_runner.py)

  • graph_runner.py already validates FA3 availability at CUDA Graph init time, but the eager mode path (through update_step_context) had no such check. Without FA3, speculative decoding would fail with an unhelpful ImportError deep inside update_meta_flashattn.
  • Added a RuntimeError with a clear message in op_backend.py:update_step_context when model_paradigm == "ar_spec" and FA3 is unavailable, matching the graph runner check.
  • Error messages include detected SM version and CUDA version for easier diagnosis, and correctly attribute failure to either missing flash-attn installation or insufficient GPU capability.
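
A rough sketch of such a check follows. It is not the PR's implementation: the function names, parameters, and message wording are hypothetical, and the caller is assumed to have already determined the SM version, CUDA version, and whether flash-attn imports.

```python
def fa3_error_message(sm, cuda_version, hardware_ok, flash_attn_installed):
    """Build a diagnostic that names the detected environment and the
    most likely cause. All parameter names here are illustrative."""
    detected = f'Detected SM{sm[0]}{sm[1]}, CUDA {cuda_version}.'
    if hardware_ok and not flash_attn_installed:
        cause = 'flash-attn (FA3) does not appear to be installed.'
    else:
        cause = 'FA3 needs an SM80+ GPU with CUDA >= 12.3.'
    return ('Speculative decoding on CUDA requires FlashAttention-3 (FA3). '
            f'{detected} {cause} '
            'Install flash-attn on supported hardware or disable '
            'speculative decoding.')


def check_fa3_for_spec_decoding(model_paradigm, use_fa3, **env):
    """Fail early in the eager path instead of hitting an ImportError later."""
    if model_paradigm == 'ar_spec' and not use_fa3:
        raise RuntimeError(fa3_error_message(**env))


try:
    check_fa3_for_spec_decoding('ar_spec', use_fa3=False, sm=(8, 0),
                                cuda_version='12.4', hardware_ok=True,
                                flash_attn_installed=False)
except RuntimeError as err:
    print(err)
```

Separating the message builder from the check makes it straightforward for the eager path and the graph runner to share the same wording.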

4. Normalize sliding_window in update_meta_flashattn (op_backend.py)

  • Replaced raw tuple multiply (sliding_window,) * 2 with _normalize_sliding_window(), consistent with how cudagraph.py and attention/__init__.py handle None/int/tuple sliding_window values. This prevents producing invalid (None, None) or nested tuples.
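
The failure mode of the raw tuple multiply is easy to demonstrate. The normalizer below is a hypothetical stand-in for _normalize_sliding_window(), whose actual body is not shown in this PR page; in particular, mapping None to (-1, -1) is an assumption based on the usual flash-attention convention for "no window".

```python
def normalize_sliding_window(sliding_window):
    """Hypothetical sketch: always return a flat (left, right) pair of ints."""
    if sliding_window is None:
        # Assumed sentinel for "no sliding window".
        return (-1, -1)
    if isinstance(sliding_window, int):
        return (sliding_window, sliding_window)
    left, right = sliding_window
    return (left, right)


# The raw multiply breaks for anything that is not a plain int.
print((None,) * 2)
print(((128, 128),) * 2)

# The normalizer yields a flat pair in every case.
for value in (None, 128, (128, 0)):
    print(normalize_sliding_window(value))
```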

5. Remove dead code (configurations/utils.py)

  • Removed flash_attn_v3_available(), which was defined but never called. The canonical FA3 availability check is the use_fa3 module variable in attention/__init__.py.

BC-breaking (Optional)

No BC-breaking changes. The SM capability threshold is relaxed (SM90 → SM80+), which only enables FA3 on more hardware. The max_seqlen_k change fixes a correctness bug. The eager-mode check raises an explicit RuntimeError where an opaque ImportError would previously occur.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. (FA3 availability depends on GPU hardware and flash-attn installation; changes are validated by pre-commit and manual testing on target hardware.)
  3. If the modification has a dependency on downstream projects of a newer version, this PR should be tested with all supported versions of downstream projects.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

windreamer and others added 4 commits May 18, 2026 11:18
…ng on CUDA

- Update use_fa3 capability check from == 9 (SM90 only) to >= 8 (SM80+)
  in attention/__init__.py and configurations/utils.py
- Add FA3 requirement check in graph_runner.py: speculative decoding
  on CUDA now raises a clear error if FA3 is unavailable, instead of
  crashing deep in the Triton paged attention kernel
- Update docstrings/error messages to reflect SM80+ (Ampere) support
…tion

FA3 mha_fwd derives seqlen_k from page_table.shape[1] * page_size
for paged KV without cu_seqlens_k. get_scheduler_metadata must
receive the same value to produce a consistent scheduler layout.

Previously max_seqlen_k was incorrectly set to
step_context.max_kv_seqlen (runtime KV length) in op_backend.py, and
decode_query_len or attn_metadata.max_kv_seqlen in cudagraph.py.
These values differ from what FA3 computes internally, causing
scheduler_metadata to be misaligned with the actual kernel behavior.

- op_backend.py: use block_offsets.size(1) * block_size
- cudagraph.py: use graph_meta.num_blocks * graph_meta.block_size

Both now match FA3 internal: page_table.size(1) * page_size.

Co-authored-by: openhands <openhands@all-hands.dev>

- Add RuntimeError in op_backend.py update_step_context when speculative
  decoding is used without FA3, matching the existing check in
  graph_runner.py. Previously only the CUDA Graph path validated FA3
  availability; the eager path would fail with an unhelpful ImportError.
- Remove unused flash_attn_v3_available() from configurations/utils.py.
  The canonical FA3 availability check lives in attention/__init__.py
  (use_fa3 module variable) which is already imported by graph_runner.py.

String comparison '12.10' >= '12.3' evaluates to False because it
compares character by character. Switch to tuple-of-ints comparison
to correctly handle minor versions with different digit counts.
Copilot AI review requested due to automatic review settings May 18, 2026 03:29
Contributor

Copilot AI left a comment

Pull request overview

This PR updates the CUDA backend to correctly enable and use FlashAttention-3 (FA3) on SM80+ GPUs (Ampere and newer) with CUDA ≥ 12.3, and fixes FA3 scheduler-metadata sizing for paged-KV speculative decoding to avoid buffer mismatches/correctness issues.

Changes:

  • Relax FA3 GPU capability gating to SM80+ and fix CUDA version comparison to use numeric parsing (not string compare).
  • Fix FA3 max_seqlen_k sizing to match FA3’s internal paged-KV computation (page_table.size(1) * page_size) across eager + CUDA Graph paths.
  • Add explicit FA3 availability checks/errors for speculative decoding paths; remove unused FA3 availability helper.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| lmdeploy/pytorch/models/utils/cudagraph.py | Adjust FA3 scheduler metadata buffer sizing in CUDA Graph capture/replay to use num_blocks * block_size. |
| lmdeploy/pytorch/configurations/utils.py | Remove unused flash_attn_v3_available() helper. |
| lmdeploy/pytorch/backends/cuda/op_backend.py | Compute FA3 max_seqlen_k from block_offsets shape; add eager speculative-decoding FA3 availability error. |
| lmdeploy/pytorch/backends/cuda/graph_runner.py | Add early FA3 requirement check for CUDA Graph runner when using speculative decoding. |
| lmdeploy/pytorch/backends/cuda/attention/fa3.py | Update docstring to reflect SM80+ support. |
| lmdeploy/pytorch/backends/cuda/attention/__init__.py | Enable FA3 on SM80+ and fix CUDA version comparison logic. |

Comments suppressed due to low confidence (1)

lmdeploy/pytorch/backends/cuda/op_backend.py:241

  • The raised error claims the "Current GPU does not meet these requirements", but use_fa3 can also be false due to missing/failed flash-attn install or missing torch.ops.flash_attn_3. Consider rewording to avoid attributing the failure solely to GPU capability, and ideally include detected SM/CUDA version (or that flash-attn is not installed) to make the error actionable.
                from .attention import use_fa3
                if not use_fa3:
                    raise RuntimeError(
                        'Speculative decoding on CUDA requires FlashAttention-3 (FA3), '
                        'which is available on SM80+ GPUs (Ampere architecture and above) '
                        'with CUDA >= 12.3. Current GPU does not meet these requirements. '
                        'Please use a SM80+ GPU with CUDA >= 12.3 and install flash-attn, '
                        'or disable speculative decoding.')

Comment thread lmdeploy/pytorch/backends/cuda/op_backend.py
Comment thread lmdeploy/pytorch/backends/cuda/graph_runner.py Outdated
- Use _normalize_sliding_window() instead of raw tuple multiply in
  op_backend.py update_meta_flashattn, consistent with cudagraph.py
  and attention/__init__.py handling of None/int/tuple sliding_window.
- Improve FA3 RuntimeError messages in both op_backend.py and
  graph_runner.py: include detected SM version and CUDA version, and
  avoid misleadingly attributing failure solely to GPU capability
  (flash-attn may simply not be installed).
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

Comment thread lmdeploy/pytorch/backends/cuda/graph_runner.py Outdated
Comment thread lmdeploy/pytorch/backends/cuda/graph_runner.py Outdated
DeepSeek MTP models use FlashMLA instead of FA3 for speculative
decoding. The FA3 requirement check and use_fa3_decoding flag should
not trigger when use_flash_mla is enabled, matching the logic in
CudaOpsBackend.update_step_context where FlashMLA takes priority.
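
The gating described in this commit message can be summarized as a small predicate. This is a sketch only: the function name is hypothetical, and the "ar_spec" paradigm value is taken from the PR description rather than from the graph runner's code.

```python
def requires_fa3(model_paradigm, use_flash_mla):
    """Hypothetical predicate: speculative decoding needs FA3 unless the
    model is served through FlashMLA, which takes priority."""
    return model_paradigm == 'ar_spec' and not use_flash_mla


# FlashMLA-backed speculative decoding (e.g. DeepSeek MTP) skips the FA3 check.
print(requires_fa3('ar_spec', use_flash_mla=True))
# Other speculative-decoding models still need FA3.
print(requires_fa3('ar_spec', use_flash_mla=False))
```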