Skip to content

[WIP] feat: Add sm120 support for DeepGEMM #324

Open
leavelet wants to merge 29 commits into
deepseek-ai:mainfrom
leavelet:sm120
Open

[WIP] feat: Add sm120 support for DeepGEMM #324
leavelet wants to merge 29 commits into
deepseek-ai:mainfrom
leavelet:sm120

Conversation

@leavelet
Copy link
Copy Markdown

@leavelet leavelet commented May 1, 2026

Added sm120 support for DeepGEMM, with performance on par with cuBLAS and outperforming the CUTLASS version in most cases.

Important: This repository is still a work in progress and has not been cleaned up yet. It is being released early to prevent duplicate effort in the community.

sm120 support will be maintained in the nv_dev branch by NVIDIA DevTech APAC.

Performance metrics:

Dense GEMM

Precision Peak TFLOPS % of MMA Peak Notes
FP8 619 76% block-scaled UE8M0, BF16 output
FP4 1239 81% block-scaled UE8M0, gran_k=32, BF16 output
BF16 374 98.5% No scale factors
  • FP4 vs FP8 speedup: 1.6–1.9× on large shapes.
  • BF16 achieves the highest MMA utilization (98.5%) due to being more compute-bound.

Grouped GEMM (FP8 / FP4)

Kernel FP8 TFLOPS FP4 TFLOPS
K-grouped (EP64, 4 groups) 551 828
M-grouped contiguous (4g) 728 1372
M-grouped masked (6g) 633 1169
  • FP4 vs FP8 speedup: ~1.8× on M-grouped contiguous, ~1.7–1.8× on masked.
  • M-grouped contiguous exceeds dense TFLOPS due to larger effective M (~34K), giving better wave efficiency across SMs.

Einsum (Batched GEMM)

Variant FP8 TFLOPS BF16 TFLOPS
K-major B (bhr,hdr->bhd) 682 384
MN-major B (bhd,hdr->bhr) 664 ~410
Split-S reduction (bmk,bnk->mn) 188 (BW-limited)
  • MN-major B matches K-major performance (99–102%) via ldmatrix.trans.x2 + multi-atom BLOCK_N=128 optimization.

HC Prenorm (Fused GEMM + sqr_sum, TF32)

M K Splits TFLOPS GB/s
8192 28672 16 28.8 1242
4096 28672 16 25.1 1087
  • Inherently memory-bound (N=24). Split-K provides up to 7× speedup for small M.

MQA Logits (Attention, FP8)

Mode Peak TFLOPS Notes
Dense (ragged) 651 (80% peak) Warp-specialized, L2-cached KV
Paged 322 DRAM BW-limited (1.35 TB/s)

Co-authored by: @lucifer1004

leavelet and others added 10 commits April 26, 2026 09:54
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0).

Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors,
B128 XOR swizzle, persistent scheduling, register-based epilogue.

New files:
- SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils
- CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages)

Modified files:
- Arch detection, compiler flags (-gencode for SM120a)
- API dispatch (arch_major == 12), SF layout transform
- Default recipe for SM120

Correctness: 8/8 shapes pass (diff < 0.001 cosine distance)
Performance: ~73 TFLOPS (baseline, optimization pending)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only

Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100
architecture), merging the warp-specialized implementation into the main
sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM
with the mxf4nvf4 m16n8k64 MMA instruction.

Key changes:
- Consolidate: remove non-spec path, always BM=128/BK=128/384 threads
- FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard
  ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4
- Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4
- API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding
- Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer
  loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4,
  fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment,
  tensor map descriptors at end), fix epilogue bounds for multi-group output.

- MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline
  asm using "+f" read-write constraints for accumulator registers. Eliminates
  CUTLASS header dependency and gives explicit control over MMA operand
  encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA.

- JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA
  descriptor creation (first_k base, FP4-aware stride), SF TMA covering
  concatenated groups, CD TMA with num_groups outer dimension.

- API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous,
  relax recipe assertion to support gran_k=32/128, add SM120 SF layout
  transform with auto-detection of transposed K-major scale factors.

- Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K
  edge case), fix K-major selection for SM120 in generators, fix test
  dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison.

Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support:
- Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store
- Add IndexType::SF_K for batched SF coordinate computation
- Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint)
- BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192)

M-grouped BF16:
- Add contiguous and masked M-grouped BF16 GEMM launchers

HC prenorm TF32:
- New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak)
- BF16 A -> FP32 cast with fused sqr_sum accumulation
- Atom-aware FP32 B fragment loading from K-major SMEM
- Split-K support for large K / small M shapes
- 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass.
Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass.

Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale).
8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction
across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time.
Global stores are fire-and-forget on SM120a, so no epilogue warps needed.

Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages,
3 KV stages, 84KB SMEM (83% of 101KB capacity).

Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline.
Fixed metadata split_kv mismatch and register budget overflow (TMA regs
64→40 to stay within 65536 register limit).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits.
  Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead
  of n; kernel uses stride_d parameter for D row stride and bounds checks;
  TMA store coordinates apply epilogue N-index remapping.

- MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B.
  Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false.

- FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride
  from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with
ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads
column-major 8x8 BF16 matrices — directly producing MMA B fragments.

Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T
(M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to
heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint).

Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS,
MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans
correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5).

MN-major B now achieves 99-102% of K-major performance (was 53% with
scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr

New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output
(188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and
bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride
formula: replace single stride_d with stride_cd_m + stride_cd_batch to
support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor
template parameter to FP8 kernel with verified scalar-load MN-major B path
(correct but 3x slower than K-major ldmatrix, kept for future optimization).
@leavelet leavelet mentioned this pull request May 1, 2026
@leavelet leavelet changed the title [WIP] Feat: Add sm120 support for DeepGEMM [WIP] feat: Add sm120 support for DeepGEMM May 1, 2026
leavelet added 9 commits May 1, 2026 03:05
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA
(kKGroupedConstantStride) — per-group only replaces addr+dim, not stride.
TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem.

Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh:
- TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom
- Prefetch advance: use get_atom_advance() instead of hardcoded +1
- Math loop: conditional iteration count via is_paired_atom for unpaired atoms
- KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X).
Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2.
Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale.

Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8.
Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
…dmatrix

Uses mxf8f6f4 block_scale e4m3.e2m1 MMA (K=32, scale_vec::1X).
B path: TMA 16U4_ALIGN16B → padded SMEM → ldmatrix.m8n16.x2.b8x16.b4x16_p64
(hardware unpack 4-bit→8-bit) → <<2 shift → mixed MMA. Zero host-side overhead.
SMEM budget identical to pure FP8 (padded row = BLOCK_K bytes).
Peak 670 TFLOPS on 4096×24576×1536.
…GEMM

Extend b_is_fp4 detection and TMA .b4x16_p64 descriptor to all 4 grouped
launchers (K-grouped, M-grouped contiguous, M-grouped masked, batched).
Also fix K-grouped sum_sf_k bug: split into sum_sf_k_a/sum_sf_k_b for
correct SF TMA dimensions when gran_k_a != gran_k_b.
…uards to MQA kernels

BF16 kernel TMA store path was computing N coordinate directly without
epilogue_type_t::apply_index_n, unlike the FP8/FP4 kernel which correctly
applies the transform. Also adds #pragma clang diagnostic and
#if __CUDA_ARCH__ >= 1200 guards to all 4 SM120 MQA logits kernels,
matching the pattern used by all other SM120 kernel files.
Sub-tile epilogue: stream 64-row sub-tiles through compact SMEM_D buffer
instead of full 128-row tile. SMEM_D reduced from 32KB to 16KB for BN=128,
recovering 1 pipeline stage (3→4 for FP4 BN=128). Heuristic auto-selects
sub-tile when it gains stages. TMA CD descriptor box outer dim matches
kEpiSubM. tma_store_wait<0> per sub-tile ensures SMEM coherence.

B fragment loading: use ldmatrix.x4 for B (2 N-tiles per load) instead of
ldmatrix.x2 (1 N-tile per load), halving B load instructions per K-step
from 16 to 8. Reduces L1/TEX pipeline pressure from 72.6% to 61.4%,
freeing bandwidth for TMA loads. MMA wrappers accept b0/b1 as separate
uint32_t args to allow direct x4 output indexing without intermediate
array copy.

FP4 4096×7168×2048: 1024T → 1091T (+6.5%)
FP4 4096×2112×7168: 1101T → 1211T (+10.0%)
FP8 4096×24576×1536: 619T → 637T (+2.9%)
All 411+ tests pass.
Replace cross-N-tile ldmatrix.x4 (2 N-tiles per load, requires MOV
rearrangement) with per-N-tile ldmatrix.x4 (1 N-tile covering 2 K-steps).
Output {r0,r1}=ks0 and {r2,r3}=ks1 are consecutive register pairs consumed
directly by MMA — verified zero MOV in SASS inner loop (27 total kernel
MOV vs 87 before, all in prologue/epilogue).

Pipeline restructured: B loaded upfront for entire K-step pair, A
double-buffered across K-steps, SF properly hoisted when gran_k >= BLOCK_K.

Fallback path retained for MN-major B and mixed FP8xFP4.

Also adds tests/sm120/run_regression.sh for automated test suite.
Split 8 math warps from (1M×16N) to (2M×8N) cooperative layout:
each warp handles 2 M-tiles × 8 N-tiles instead of 1 × 16.

Per K-block: B loads 16→8 LDSM, A loads 2→4 LDSM, net 18→12 LDSM.
L1/TEX throughput drops from 61% to 41% — SMEM load no longer bottleneck.
Tensor pipe rises to 72% of peak (from 71%).

FP4 4096×7168×2048: 1094T → 1144T (+4.6%)
FP4 4096×2112×7168: 1138T → 1278T (+12.3%)
FP4 4096×24576×1536: 1132T → 1168T (+3.2%, exceeds CUTLASS 1146T)
FP8 also improves: 597→634T, 587→622T, 637→665T
Update MMA efficiency model to better account for epilogue overhead and
A fragment reuse with cooperative warp layout. Larger BN=128 has half the
tile count (fewer epilogues) and 2x compute per K-block (better TMA
latency amortization), outweighing the stage count reduction (4 vs 7).

FP4 4096×7168×2048: 1144T → 1201T (+5.0%)
FP4 4096×2112×7168: 1278T → 1307T (+2.3%)
leavelet added 2 commits May 7, 2026 01:43
K-grouped FP4 was an SM120-exclusive feature not needed by DeepSeek V4
(which uses M-grouped FP4 for routed MoE). SM100 never supported it.
Add assert to block FP4 inputs in k_grouped_fp8_gemm_nt_contiguous.
Remove FP4 test configs from test_k_grouped_fp8.py.
SM120 mma.sync cannot natively handle MN-major operands (ldmatrix.trans
only works at b16 granularity, incompatible with 1-byte FP8). Convert
MN-major to K-major at the API dispatch layer:
- FP8: .contiguous() transpose copy
- FP4: fp4_repack_to_k_major() — unpack nibbles, rearrange, repack
  (simple .contiguous() corrupts FP4 nibble pairing)

Also fixes SF transpose kernel SMEM overflow on SM120 (99KB limit) for
large SF_K by dynamically computing block_mn from device SMEM capacity.
@leavelet leavelet marked this pull request as ready for review May 9, 2026 04:37
@linjiapro
Copy link
Copy Markdown

@leavelet this is nice, after it is merged into the nv-dev branch, should vllm-project/vllm#41834 merge too in order for vllm to be able to work with the branch.

@jasl
Copy link
Copy Markdown
Contributor

jasl commented May 10, 2026

I add my benchmark at vllm-project/vllm#41834 as references

jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to
tokens==1) with a port of the SM120 FP8 einsum kernel from upstream
DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file
`deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM
kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs,
with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and
the same block-128 fp32 scale recipe.

Removing the tokens==1 gate: the kernel handles all token counts that
the SM12x dispatch predicate accepts (tokens <= 16 today; larger token
batches will arrive once T1-α expands graph capture).

Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024,
GPU idle):

  tokens=1  cuda 0.026ms  triton 0.020ms  speedup 0.72x (was 0.40x)
  tokens=2  cuda 0.027ms  triton 0.020ms  speedup 0.72x (was 0.73x*)
  tokens=8  cuda 0.075ms  triton 0.020ms  speedup 0.27x (was 0.21x*)

  * Triton-as-default after the previous tokens==1 hotfix.

The kernel's grid is `tokens * groups * (out/128)`, one block per
`(token, group, out_tile=128)` triple. Because each block reads its
weight tile independently, total weight reads scale linearly with
`num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the
production shape and the 0.72x is a real net win against the previous
Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel
loses ~2x to Triton's m=16 cooperative tile; we will revisit with a
multi-token tile design before T1-α exposes that shape to production.

Earlier hand-written attempts (one-cell-per-block, per-thread B=16
accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA,
4-warp m16n128 MMA) are documented in
`docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA
designs hit either occupancy collapse (80 regs/thread) or insufficient
parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping
out at ~0.51x. The DeepGEMM design wins at the production shape by
avoiding tensor cores entirely -- a per-thread GEMV with fp8x4
vectorization and L1/L2-friendly weight access fits the small-M decode
profile better than the m=16 MMA tile.

Attribution: kernel source ported under MIT license from upstream
DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the
tvm-ffi binding, stride/scale validation, and the SM12x dispatch
integration; the dot-product math is unchanged.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/tokenspeed that referenced this pull request May 12, 2026
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to
``DeepseekV4Attention._compute_qr_kv``: on success it replaces the
reference FP8 linear path, on failure it logs a WARNING per layer and
falls back. DeepGEMM does not support SM120/SM121 yet (see PR
``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory),
so on the RTX Pro 6000 workstation every layer fires:

    DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference
    FP8 linear. reason=RuntimeError: Assertion error
    (csrc/apis/layout.hpp:59): Unknown SF transformation

The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag
already catches this for steady-state replay, but it costs one failed
call + one WARNING per layer at boot. Mirror the pattern used by
``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short-
circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x
so the platform never tries the DeepGEMM path. Non-SM12x platforms
keep the new fast path.

Signed-off-by: jasl <jasl9187@hotmail.com>
leavelet and others added 8 commits May 13, 2026 03:28
FP8 paged kernel hardcodes 2-group layout (BLOCK_KV=64), so block_kv=32
causes SMEM overflow. FP4 paged kernel correctly computes num_groups from
block_kv. Add assert in FP8 launcher, enable block_kv=32 only for FP4
in test. Remove redundant test_mqa_logits.py wrapper.
…text hoist

Restructure the FP8/FP4 mainloop to reduce non-QMMA instruction overhead:

- SF-major loop: load SF packed int32 into registers once per
  kSFTileKBlocks K-blocks, extract bytes with compile-time index
  via cute::for_each(make_int_sequence). Eliminates redundant SF
  SMEM loads (10 LDS/K-block → 2.5 amortized).

- Compile-time SF byte extraction: cute::for_each guarantees
  kb_inner is a compile-time type, enabling the compiler to
  specialize byte 0 (LOP3 only) and byte 3 (SHF only).

- SwizzleContext hoist: move a_ctx/b_ctx initialization outside
  the K-block loop (loop-invariant).

- Generalized sf_byte formula: (kb_inner * BLOCK_K / kGranK) % 4
  supports both BLOCK_K=128 and BLOCK_K=64.

- if constexpr(kUseSFMajorLoop) separates the optimized path
  (gran_k >= BLOCK_K) from the original flat loop (gran_k < BLOCK_K),
  eliminating ~300 lines of tail code duplication.

SASS improvement (BLOCK_K=128): total mainloop instructions 587→469,
QMMA ratio 43.6%→54.6%, IMAD 13→0, LOP3 71→34.
Add BLOCK_K=64 as a heuristic candidate for M >= 2048 non-mixed
dtype GEMMs. With BLOCK_K=64, per-stage SMEM halves (17KB vs 34KB),
enabling 4 pipeline stages (vs 2 with BLOCK_K=128). The deeper
pipeline provides 3 stages of TMA latency slack (vs 1), improving
sustained MMA pipe utilization.

Benchmark (M=4096, FP8 dense):
- 4096×2112×7168:  591T → 624T (+5.7%)
- 4096×5120×8192:  669T → 699T (+4.6%)
- 4096×5120×25600: 689T → 731T (+6.0%)

Constraints:
- M < 2048: BK=128 is faster (pipeline fill/drain overhead dominates)
- Mixed FP8×FP4: BK=128 only (FP4 .b4x16_p64 TMA incompatible)
- B64 swizzle: consequence of BK=64, verified correct on SM120a
The grouped contiguous MoE path on SM120 previously hardcoded BLOCK_M=128
in `get_layout_candidates`, and `get_theoretical_mk_alignment_for_contiguous_layout`
matched it with a 128-only floor. For batch=1 decode with 128 experts and
topk=6, this forces M_sum = 6 + 128*127 = 16384 padded rows of which only
6 are valid — kernel time scales with M_sum, so ~99% of work is on
padding blocks the kernel still has to schedule through.

Enable BLOCK_M=64 as a candidate (alongside 128 plus runtime_align if it
sits on a 16-aligned value in (64,128]); the existing cycle-based scorer
in `compare()` picks the smaller tile when expected_m is small. The
kernel-side cooperative-warp layout (kMWarps=4, MMA_M=16, kMTilesPerWarp =
BLOCK_M/64) requires BLOCK_M >= 64, so 64 is the practical floor for now;
unlocking BLOCK_M=32 needs the kMWarps refactor.

Microbenched on RTX PRO 6000 Blackwell at DSv4-Flash MoE shapes
(N=4096/2048, K=4096/2048, num_groups=128). FC1 grouped-GEMM time
medians at batch=1 decode dropped from 57 us → 37 us (-35%) and the
end-to-end vLLM TPOT at TP=2 / 8K-1K dropped from 15.76 ms → 12.48 ms,
beating the Marlin MoE path baseline (13.96 ms) by ~11%.
Lifts the BLOCK_M floor from 64 to 32 by switching the kernel's
cooperative warp layout from the hardcoded `kNWarps = 2` (4 M-warps ×
2 N-warps) to a BLOCK_M-dependent value:

  BLOCK_M  < 64 → kNWarps = 4, kMWarps = 2   (smaller A tile)
  BLOCK_M >= 64 → kNWarps = 2, kMWarps = 4   (default, more A reuse)

This keeps kMTilesPerWarp = BLOCK_M / kMWarps / MMA_M = 1 in both cases
and adds a static assert. For the kNWarps=4 path we need BLOCK_N % 32 == 0
so kNTiles = BLOCK_N/8 is divisible by kNWarps; sm120.hpp filters
(block_m, block_n) accordingly.

The runtime alignment helper's SM120 min_block_m drops from 64 to 32
so `get_theoretical_mk_alignment_for_contiguous_layout(expected_m)`
returns 32 for `expected_m <= 32` (e.g. MoE decode batch=1, topk=6:
`expected_m = M*topk = 6` → align=32 → M_sum = 6 + 6*31 = 192, vs 384
at BLOCK_M=64, or the original 16384).

To avoid the cycle-based `compare()` scorer over-selecting BLOCK_M=32
for prefill-sized expected_m (where the launch/pipeline overhead of
many tiny tiles isn't in the model), small-BLOCK_M candidates are
filtered to `expected_m <= 4 * BLOCK_M`. That keeps batch=1..8 on
BLOCK_M=32 and batch≥16 on BLOCK_M=128.

Microbench on RTX PRO 6000 Blackwell, DSv4-Flash MoE shapes (N=4096,
K=4096, num_groups=128), FC1 grouped-GEMM time medians:

  batch  M_sum   BLOCK_M=64   BLOCK_M=32+gate   Marlin baseline
      1    192       37.4 us          29.0 us           31.2 us
      2    384       55.9            37.3              54.9
      4    768      168.5            56.9             135.6
      8   2304      337.6           151.2             247.4
     32  16512      949.8           943.2             589.5
   1024  22400     1251.9          1210.9             833.1

End-to-end vLLM bench (DSv4-Flash, TP=2, 8K/1K decode), median TPOT:

  Marlin baseline:        13.96 ms
  DEEPGEMM_MXFP4
    + BLOCK_M=64:         12.48 ms   (-11% vs Marlin)
    + BLOCK_M=32 gated:   (pending — verify e2e in vLLM)
Mirrors TRT-LLM blockscale-gemm runGemmSwapAB
(cpp/include/tensorrt_llm/deep_gemm/fp8_gemm.cuh). The SM120 1d1d kernel
has BLOCK_M >= 64, so decode-shape BMM with M_orig <= 32 wastes most M
lanes. The caller in apis/einsum.hpp::fp8_bmm now detects this case,
swaps A<->B (and sfa<->sfb before the SF layout transform), and calls
sm120_fp8_fp4_bmm with swap_ab=true. The heuristic relaxes the BLOCK_N
floor only when desc.n <= 32 (post-swap signature), so the normal path
keeps BN in {64, 128}.

The kernel takes a new runtime stride_cd_n. In normal mode it is 1 (cols
adjacent, pair-write fast path). In swap mode it is d.stride(-2) = N_orig,
so the kernel's (m_kernel=n_orig, n_kernel=m_orig) writes land byte-for-
byte in the caller's (B, M_orig, N_orig) row-major buffer with no temp
allocation. The direct-store epilogue picks pair vs single-element based
on stride_cd_n == 1. swizzle_cd_mode is forced to 0 below BN*cd_size=128
so kUseTMAStoreEpilogue is off in swap mode (the natural TMA-CD layout
can't write the swapped indices correctly).

Correctness: bit-identical output to jasl Triton einsum across M=1..32
(torch.equal == True, max abs diff 0).
Perf: ~25% faster than jasl at kernel level (ncu) across M=1..16.
Normal path (M >= 33) unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires the same swap-AB dispatch into apis/gemm.hpp::fp8_fp4_gemm_nt so
decode-shape MLP / projection GEMMs (M_orig <= 32) also bypass the
BLOCK_M = 64 floor. sm120_fp8_fp4_gemm_1d1d gains the swap_ab flag and
runtime stride_cd_n the bmm variant already used, so the kernel writes
directly into the caller's (M_orig, N_orig) row-major buffer with no
temp allocation.

Recipes can be asymmetric on plain GEMM (per-token A + per-block B
defaults to (1, 128, 128) on SM12). When we swap operands we must also
swap the recipe's gran_mn entries; otherwise transform_sf_into_required
trips on sf.size(-2) != ceil_div(mn, gran_mn). Both gemm.hpp and
einsum.hpp now build a swapped recipe and use the recipe_a/recipe_b
form of transform_sf_pair_into_required_layout. einsum.hpp was correct
for DSv4's symmetric (1, 1, 128) recipe before this change but would
have silently produced shape errors on any asymmetric caller.

Verified: 11/11 plain-GEMM correctness shapes at M=1..32 (swap) and
M=64,128 (normal) within FP8 tolerance; all upstream SM120 regression
tests pass (8/8 dense FP8, 7/7 k-grouped, 34/34 m-grouped FP8/FP4);
einsum BMM swap path remains bit-identical to jasl Triton at M=1..32.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 20, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
@Rachmanino
Copy link
Copy Markdown

nice work! may I ask the hardware for testing here is either 5090 or RTX6000pro?

jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
jasl added a commit to jasl/vllm that referenced this pull request May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch.

The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code.

On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95.

Signed-off-by: jasl <jasl9187@hotmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants