[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes#41834
[New Model][Nvidia] Add SM12x support for DeepSeek V4 Flash with essential fixes#41834jasl wants to merge 54 commits into
Conversation
|
@zyongye |
There was a problem hiding this comment.
Code Review
This pull request implements support for DeepSeek V4 on SM12x (Blackwell) architectures by providing Triton-based fallbacks for DeepGEMM-dependent operations. Key enhancements include the introduction of specialized Triton kernels for sparse MLA, FP8 einsum, and MQA logits, as well as memory optimizations in the sparse attention indexer to compute top-k indices without materializing full logits. Additionally, the PR updates the model loader to support weight name filtering for skipping MTP weights and handles Blackwell-specific FP8 quantization scales. I have no feedback to provide.
💡 Codex Reviewvllm/vllm/model_executor/layers/sparse_attn_indexer.py Lines 86 to 89 in 9596dbf This helper now disables the DeepGEMM requirement for every SM120 run, but the FP4 indexer cache path still depends on DeepGEMM kernels ( vllm/vllm/model_executor/model_loader/default_loader.py Lines 236 to 240 in 9596dbf The new pre-load ℹ️ About Codex in GitHubYour team has set up Codex to review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. Codex can also answer questions or update the PR. Try commenting "@codex address that feedback". |
042e366 to
df2e6f8
Compare
…date loop Redesigned suggestion #3 from PR vllm-project#41834 comment 4450901180. The first attempt (e34daef, reverted; later 72a5ff2, also reverted) tried to truncate ``topk_indices.shape[1]`` in Python so the captured launches iterated a narrower combined slice; that approach broke under cudagraph replay (shape baked at capture) and *also* mis-bounded — the combine kernel writes each token's combined buffer as ``[topk_len_t | swa_len_t | -1 padding]`` with SWA *immediately* following the per-token topk, so a fixed ``effective_topk`` cap cuts off the SWA portion (GSM8K dropped 25 pp on the prior attempt). The kernel already loads the per-token combined length (``valid_len = tl.load(lens_ptr + token_idx)`` for the four ``lens``- gated kernels, ``gather_len`` for the two paged kernels). The existing ``is_valid`` guard only short-circuits the *heavy* work past that length; the outer ``for candidate_idx in range(0, num_candidates)`` still pays one ``tl.load`` + branch per iter on the dead tail. Capping the loop at ``min(num_candidates, valid_len - candidate_offset)`` (clamped to 0) removes those wasted iterations while preserving the existing ``is_valid`` semantics: the iterations we now skip are exactly those the existing guard already discarded. Applied to six accumulate kernels in ``sparse_mla_kernels.py``: - ``_accumulate_gathered_attention_chunk_kernel`` - ``_accumulate_indexed_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel`` [decode] - ``_accumulate_fp8ds_paged_attention_chunk_kernel`` [autotuned in #1] - ``_accumulate_fp8ds_paged_attention_chunk_multihead_kernel`` [decode] CUDA-graph safety: ``lens_ptr`` / ``gather_lens_ptr`` are stable addresses; their values are refreshed per call by the metadata builder (outside the captured forward) and by ``combine_topk_swa_indices`` (inside the forward but writing only into the persistent buffers the accumulate kernels read from). The kernel inner-loop bound is a runtime-loaded scalar — Triton compiles a dynamic loop and the captured launch picks up the current value on each replay. Savings scale with ``combined_topk_buffer_width - actual valid length`` (i.e. mostly visible at long ``max_model_len`` with shorter actual contexts). At our test shape (``max_model_len=131072``, ISL=2048) the saved iterations come mostly from the decode multihead path; expected to be neutral / no-regression at short ``max_model_len`` where the bound equals ``num_candidates``. Signed-off-by: jasl <jasl9187@hotmail.com>
Three pure comment/docstring fixes from the audit, no behavior change: 1. ``_build_c128a_topk_metadata_kernel`` comment was ambiguous about ``max_compressed_tokens`` after the parameter was renamed to ``effective_topk`` in 304944e. Reword to explicitly point at the Python caller (``build_c128a_topk_metadata``) and explain that ``max_compressed_tokens`` is the buffer column width and entries past ``effective_topk`` stay at ``-1`` via the caller's ``fill_(-1)`` pre-pass. 2. Add an inline note next to ``positions.max().item()`` flagging it as a host sync that is safe here because the builder runs outside the captured forward. 3. Expand ``MLAAttentionManager`` class docstring: the predicate ``_should_protect_prompt_blocks`` triggers on three independent conditions (DSv4 model_version, fp8_ds_mla cache_dtype_str, or compress_ratio > 1), not just DSv4. Document the three conditions inline so a future tightening pass does not accidentally narrow the coverage. Signed-off-by: jasl <jasl9187@hotmail.com>
…kip_weight_name Two refactors from the audit, no behavior change: 1. ``vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py`` had its own copy of ``_upcast_e8m0_to_fp32`` (4 lines, identical to the canonical helper at ``vllm/model_executor/layers/quantization/utils/fp8_utils. py:1017``). Other peer call sites (cutlass.py, rocm_aiter_mla_sparse. py, mxfp4.py) already import from ``fp8_utils``; do the same here. 2. ``DeepseekV4ForCausalLM.skip_weight_name_before_load`` used ``hf_to_vllm_mapper.apply_list([name])`` to map a single name. That builds a one-element list and routes through a list-comprehension that filters ``None``. Use the canonical 1-to-1 helper ``WeightsMapper._map_name`` directly, matching the pattern used in ``compressed_tensors.py``, ``adapters.py``, ``bitsandbytes_loader. py``, and ``lora/utils.py``. Same semantics, 3 lines instead of 5. Signed-off-by: jasl <jasl9187@hotmail.com>
…te kernels
After ``a94e7c289 sm12x: per-token early-loop-exit on sparse MLA
accumulate inner candidate loop`` capped each inner loop at
``local_eff = min(num_candidates, max(valid_len - candidate_offset, 0))``
(or the ``gather_len`` equivalent for the paged kernels), the per-iter
check ``(candidate_offset + candidate_idx) < valid_len`` /
``gather_idx < gather_len`` became structurally always true: by
construction every iteration's index sits inside the valid range.
This commit drops the tautological term in 7 sparse MLA accumulate
kernels and leaves the remaining cell-sentinel guard:
- ``accumulate_..._gathered_chunk`` (was: ``(...) < valid_len`` then
AND with ``slot_id >= 0`` when ``HAS_SLOT_IDS``): now just
``is_valid = slot_id >= 0`` (or ``True`` when ``HAS_SLOT_IDS`` is
false). The branch on ``HAS_SLOT_IDS`` becomes a ``tl.constexpr``
binary, which Triton compiles into two clean specialisations.
- ``accumulate_..._indexed_chunk``: ``is_valid = kv_index >= 0``.
- ``accumulate_fp8ds_global_slots_sparse_mla_attention_chunk{,_multihead}``:
``is_valid = slot_id >= 0``.
- ``accumulate_fp8ds_paged_sparse_mla_attention_chunk{,_multihead,
_multihead_with_sink}``: there is no per-cell sentinel here, so the
whole ``is_valid`` variable and ``if is_valid:`` guard go away and
the loop body becomes unconditional.
Each touched site gains a 2-3 line comment explaining the invariant so
a future reader can see why no per-iter clamp is needed. No behavioral
change: Triton was already eliminating the tautology after the SSA
pass; this commit makes the intent explicit at the source level.
Signed-off-by: jasl <jasl9187@hotmail.com>
PR vllm-project#42258 introduced SlidingWindowManager._cache_block_mask() to skip caching SWA blocks that can never serve a prefix-cache hit. When Eagle/MTP speculative decoding is active the mask is too aggressive — it skips blocks that eagle's modified lookup actually needs, resulting in 0% prefix cache hit rate. Eagle changes the SWA hit logic in two ways: 1. sliding_window_contiguous_blocks += 1 (needs one extra block) 2. post_pop_blocks = i (instead of i+1), shifting alignment Fix: detect SWA managers inside eagle attention groups at coordinator init time and disable the cache block mask for them. Signed-off-by: Alex Bilichenko <abilichenko@gmail.com> (cherry picked from commit b90c495) Signed-off-by: jasl <jasl9187@hotmail.com>
Two prefill performance fixes for SM12x DeepSeek V4: 1. Add _accumulate_indexed_attention_chunk_multihead_kernel (HEAD_BLOCK=8) that loads KV once per candidate and reuses across 8 heads, reducing L2 traffic in the prefill accumulate phase. Same pattern as the existing decode _finish_materialized_scores_with_sink_kernel. Prefill throughput on 2× RTX PRO 6000 WS, TP=2, MTP=2: - 1K tokens: +49% (2,746 → 4,100 tok/s) - 4.5K tokens: +37% (3,122 → 4,271 tok/s) - 18K tokens: +36% (2,474 → 3,360 tok/s) - 64K tokens: +28% (1,679 → 2,146 tok/s) Tuned config: HEAD_BLOCK=8, num_warps=4, num_stages=2. Benchmarked against HEAD_BLOCK=4 and num_warps=8 variants — HEAD_BLOCK=8 with num_warps=4 wins at all sizes. 2. Drop @triton.autotune from _deepseek_v4_sm12x_fp8_einsum_kernel and pin num_warps=4, num_stages=3. The autotune key included num_tokens which varies per request, causing ~200 unique keys with zero cache hits — re-benchmarking 4 configs at ~1s each on every request. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> (cherry picked from commit 9c2e7ca) Signed-off-by: jasl <jasl9187@hotmail.com>
…s × 3 variants) Tuned with Triton 3.6.0 at vllm@c92696943 on NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition, 10 M-buckets per shape (1, 2, 4, 8, 16, 32, 64, 128, 256, 512), 640-config search space filtered to BLOCK_SIZE_M >= M/8 for M >= 64. Four typical SM12x DSv4-Flash deployment shapes — none had prior tuning in the vLLM tree at this revision: E=128, N=2048 (TP=2 + EP, production shape, 2x RTX PRO 6000) E=64, N=2048 (TP=4 + EP, 4x RTX PRO 6000) — tune ~1h47m E=32, N=2048 (TP=8 + EP, 8x RTX PRO 6000) — tune ~1h00m E=256, N=1024 (TP=2 no-EP fallback, 2x RTX PRO 6000) — tune ~2h14m Aliased identically to the Max-Q Workstation Edition and Server Edition variants since they share the same silicon (GB202) as the Workstation Edition and only differ in power/form-factor envelope; copying yields the same Triton autotune optima. Signed-off-by: jasl <jasl9187@hotmail.com>
Route the direct FP8 MQA top-k fallback through the existing Triton logits kernel when the materialized logits fit within a bounded workspace, then keep the previous PyTorch path as the fallback for larger or unsupported shapes. This removes the 127K prefill PyTorch/CUTLASS logits hotspot on RTX PRO 6000 while preserving the short-context path. Local SM120 validation showed 127K C=1 cold TTFT mean improving from 60.83s to 37.65s with no short-context regression. Inspired-by: vllm-project#41834 comment 4476480477 Signed-off-by: jasl <jasl9187@hotmail.com>
Use the existing top_k_per_row_prefill CUDA op after the SM120 Triton MQA logits fallback materializes logits. The op writes int32 indices directly and respects row bounds, avoiding torch.topk plus int64-to-int32 copy in the long-context prefill path. On RTX PRO 6000 TP=2, 127K C=1 cold TTFT mean improved from 37.65s to 36.87s with no short-context regression across C=1/2/4. Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Detect CUTeDSL by importing cutlass instead of only checking package metadata. This lets SM12x deployments fall back to Triton when a broken or incompatible nvidia-cutlass-dsl install leaves the cutlass package visible but not importable. Reported-by: danielwu1987 Signed-off-by: jasl <jasl9187@hotmail.com>
Keep FULL_AND_PIECEWISE enabled for DeepSeek V4 MTP and avoid replaying small speculative decode batches against padded virtual requests by preserving exact spec-decode capture sizes for request counts 1..32. Signed-off-by: jasl <jasl9187@hotmail.com>
The model-specific TileLang warmup did not reduce startup time, first-request JIT warnings, 127K C=1 TTFT, or short C=4 correctness in the SM120 ablation. Drop the warmup hook and env knobs instead of keeping a dead A/B path. Signed-off-by: jasl <jasl9187@hotmail.com>
Widen the direct SM120 FP8 MQA logits Triton row tile from 16 to 64 while keeping BLOCK_N=128, BLOCK_D=64, and four warps. SM120 microbench improved about 22-23% across 128/256/512/1024 query rows. Same-host C=1 repeat gates improved 59K mean TTFT from 11.413s to 11.097s and 124K from 29.868s to 28.042s. Short MT-Bench C=1/2/4 and GSM8K limit-200 temperature=0 passed. Artifacts: codex_mqa_blockm_followup_20260520040105, codex_blockm64_c1_repeat/20260520041210, codex_blockm16_c1_repeat_baseline/20260520041627, codex_blockm64_short_gsm8k_gate/20260520042117. Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
The splitKV sparse MLA decode path stayed behind a default-off flag after benchmarking showed ambiguous value for the current SM120 latency target. Keep the measured matmul decode path active and preserve the experiment on backup branch codex/sm120-splitkv-decode-experiment-backup-20260521054846 for future reference. Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
Signed-off-by: jasl <jasl9187@hotmail.com>
|
This pull request has merge conflicts that must be resolved before it can be |
Prefix-cache cliff at the stress-gate shape on
|
| filler-words | older HEAD (5b39027be) |
current HEAD (c1b41cea0) |
experimental (afb7041759) |
|---|---|---|---|
| 100 | 0% | 0% | 4% |
| 400 | 50% | 50% | 48% |
| 800 | 81% | 55% | 78% |
| 1600 | 86% | 84% | 83% |
| 3200 | 93% | 89% | 90% |
3 trials per cell, 2 concurrent sessions × 3 turns. NCCL ruled out as a cause — rebuilt all three images against nvidia-nccl-cu13 2.30.4 (vs 2.28.9 default in the base image), same numbers in both versions.
The ds4-sm120-experimental branch with da4f1c711e "Avoid releasing active prefix cache hits" recovers the 800-fw cell to 78%. The commit title literally describes the symptom. Forward-porting that commit (or whichever predecessor fixes the same thing) into the PR HEAD is the obvious next step.
Have you re-run the prefix-cache stress gate against c1b41cea0 since the latest rebase? Your consolidated matrix at 10:39 UTC yesterday was on fef36bcc. If c1b41cea0 reproduces the cliff on your machine, the experimental commit is the prime candidate. If it doesn't, something in my setup is unique and I'd like to know what.
I'm guessing rebase upstream introduces the issue. Let me take a look. |
Purpose
Enable DeepSeek V4 Flash on SM12x Blackwell consumer hardware (RTX PRO 6000 Workstation Edition, RTX 5090, DGX Spark GB10).
The core challenge: SM12x lacks the TMEM /
tcgen05instructions present on datacenter Blackwell (SM10x), so DeepGEMM, FlashMLA, and Marlin's FP8 paths fail at kernel link time on this hardware. This PR provides pure-PyTorch fallbacks, Triton kernel implementations, and SM12x-specific tuning so the model runs end-to-end with production-quality perf.Validation results
Hardware: 2x NVIDIA RTX PRO 6000 Blackwell Workstation Edition. PR head:
3424fba51. Rebased onupstream/main2026-05-19; NCCL:nvidia-nccl-cu132.30.4.Reference long-context serve config used for the 2026-05-18 run:
--no-enable-prefix-cachingis set for these latency baselines so cold prefill numbers are not biased by cache hits. End-user document/chat deployments should generally keep prefix caching enabled.Accuracy
lm_evalgsm8k5-shot, 200 questions,temperature=0,max_gen_toks=2048, via/v1/completions:Within the historical 0.948-0.965 band on this model.
Performance (mt-bench,
philschmid/mt-bench, 80 prompts)MTP=2 peak: 165 tok/s single-stream, 846 tok/s @ c=24. MTP=2 acceptance length 2.35-2.38 on real-content prompts, pos-0 acceptance 84-85%.
Long-context prefill
Earlier long-context work in this PR added
_accumulate_indexed_attention_chunk_multihead_kernel(HEAD_BLOCK=8) and overlapped the C128A prefill KV gather with the indexer forward. The latest two commits add a direct SM120 MQA top-k fallback path: Triton materializes the FP8 MQA logits, then the existing customtop_k_per_row_prefillop selects top-k row indices without the slower PyTorch per-chunk score path.Dedicated 128K A/B sweep on the same 2x RTX PRO 6000 setup, C=1, cold,
max_tokens=64:f32b9e782)709f50d10)Conservative full-validation rerun on 2026-05-18, C=1, cold,
max_tokens=64, repeat=3:That conservative rerun is 37.1% lower TTFT than the pre-top-k 128K baseline (60.83 s -> 38.23 s), a 1.59x speedup.
Small-concurrency long-context matrix on 2026-05-18, cold salted prompts,
max_tokens=128, repeat=2, prefix cache disabled:Short-context warmed regression check, 4,047 prompt tokens, cold salted prompts,
max_tokens=64, repeat=2:No short-context regression versus the same-machine pre-top-k baseline (4,047 prompt tokens: C=1 0.689 s, C=2 1.125 s, C=4 2.072 s TTFT mean).
Post-rebase MTP C=4 stability and 64K/128K gate (2026-05-19)
After the upstream DeepSeek V4 refactor rebase, a short-context MTP C=4 stability blocker was traced to DeepSeek V4 MTP full decode CUDA graph replay. The current branch skips full decode CUDA graph capture for DeepSeek V4 MTP while keeping PIECEWISE CUDA graphs, and caps DeepSeek V4 MTP warmup / dummy sampler request shapes at 32. Serve logs for this run show
PIECEWISE=49and no full decode graph capture.Short-context MTP matrix, prefix cache disabled, 131K max-model-len, 4096 max-num-batched-tokens, TP=2, 16 prompts:
Full long-context promotion gate, prefix cache disabled, cold prompts,
max_tokens=128, repeat=3:GSM8K limit-200, 5-shot, MTP concurrency 1:
exact_match_flexible=0.960,exact_match_strict=0.955.Targeted regression tests for this fix:
Result:
4 passed, 16 warnings.Acceptance (toolcall-15 scenario battery)
This is the first SM12x baseline that evaluates thinking-mode correctly. Two prior harness bugs masked thinking-mode entirely across every earlier retry:
extra_body.thinking={"type":"enabled"}at the top level, which is the Claude API shape. vLLM's DSv4 chat-template entry readschat_template_kwargs.thinkinginstead, so every request silently routed to chat mode. Fixed by 323aa1f (confirmed in this PR discussion by qym-ll).message.reasoning_content, but this vLLM OpenAI frontend build populatesmessage.reasoning. The harness now normalizes both keys.The remaining failures stay concentrated in
TC-06(Multi-Value Extraction, 7/7 across modes) plus scattered TC-11 / TC-14 / TC-15: characteristic helpfulness-bias / deflect-rather-than-refuse model behaviours, not SM12x regressions.Comparison to DeepSeek's official hosted API
Same prompts run against
api.deepseek.com/v1/chat/completionswithmodel=deepseek-v4-flash, sametemperature=1.0 top_p=1.0, and the same thinking-mode shape:Per-case failure rate: hosted 4.4%, this PR 8.9-9.6%. The hosted service either ships a checkpoint we have not pulled from the HF release, or injects an internal tool-use system prompt. Either way the local vs hosted gap on this PR is the smallest it has been in any baseline shipped here.
vs 2026-05-12 deployment baseline (
1c20f1a6d, same hardware)Verification commands
Result:
4 passed, 16 warnings.Long-context matrix verification:
Results: 64K/128K matrix
PASS, 6 groups, 0 failures; warmed short-context matrixPASS, 3 groups, 0 failures.Known caveats
ProcessGroupWatchdog) and MTP=1 remains smoke-tier pending repro on NCCL 2.30.4+.vllm-project/vllm#42784fix means prefix cache does work on DSv4 SWA when enabled; a cache-on companion run is still useful for real document-chat deployment.Acknowledgments
_accumulate_indexed_attention_chunk_multihead_kernel, HEAD_BLOCK=8), patterned after the existing decode_finish_materialized_scores_with_sink_kernel._cache_block_maskover-aggression for Eagle/MTP groups, fixed byvllm-project/vllm#42784(cherry-picked locally pending upstream merge)._deepseek_v4_sm12x_fp8_einsum_kernelautotune key includingnum_tokens, causing per-request 4-config re-benchmarks; we pinned the winning config and removed the decorator._aux_stream[1]overlap ofdequantize_and_gather_k_cachewithindexer.forward).effective_topk, and the multi-head prefill kernel direction.AI assistance disclosure
Claude (Anthropic), GPT-5.4, and GPT-5.5 were used for code review, refactoring, regression-script writing, and benchmark analysis. All kernel logic and architectural decisions were validated by human review and end-to-end benchmarks before each push.