[WIP] feat: Add sm120 support for DeepGEMM #324
Open
leavelet wants to merge 29 commits into
Open
Conversation
Phase 1a: infrastructure + dense FP8 GEMM kernel for SM120a (CC 12.0). Architecture: warp-level mma.sync with block-scaled UE8M0 scale factors, B128 XOR swizzle, persistent scheduling, register-based epilogue. New files: - SM120 heuristics, JIT codegen, MMA PTX wrappers, ldmatrix/swizzle utils - CUDA kernel with warp-specialized TMA/math pipeline (3-9 stages) Modified files: - Arch detection, compiler flags (-gencode for SM120a) - API dispatch (arch_major == 12), SF layout transform - Default recipe for SM120 Correctness: 8/8 shapes pass (diff < 0.001 cosine distance) Performance: ~73 TFLOPS (baseline, optimization pending) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
… only Drop the non-warp-specialized kernel path for SM120a (matching SM90/SM100 architecture), merging the warp-specialized implementation into the main sm120_fp8_fp4_gemm_1d1d kernel. Add FP4 GEMM support using packed SMEM with the mxf4nvf4 m16n8k64 MMA instruction. Key changes: - Consolidate: remove non-spec path, always BM=128/BK=128/384 threads - FP4: packed 4-bit SMEM (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B), standard ldmatrix, uint16_t scale factors (scale_vec::2X), kKSteps=2 vs FP8's 4 - Heuristic: simplified to warp-spec only, correct SMEM sizing for FP4 - API: enable FP4 on SM120a (arch_major==12), add fp8_fp4_gemm_nt binding - Fix SF hoist bug: hoist SFA/SFB independently for mixed gran_k configs Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Kernel: Add TMA descriptor runtime update (tensormap.replace) in producer loop for K-grouped group transitions, fix SF_K_ALIGNMENT to kGranKA*4, fix SMEM layout (pipeline data at offset 0 for B128 swizzle alignment, tensor map descriptors at end), fix epilogue bounds for multi-group output. - MMA wrappers: Replace CUTLASS mma_sm120.hpp dependency with custom inline asm using "+f" read-write constraints for accumulator registers. Eliminates CUTLASS header dependency and gives explicit control over MMA operand encoding for both FP8 (m16n8k32) and FP4 (m16n8k64) block-scaled MMA. - JIT launcher: Add sm120_k_grouped_fp8_fp4_gemm_1d1d() with proper TMA descriptor creation (first_k base, FP4-aware stride), SF TMA covering concatenated groups, CD TMA with num_groups outer dimension. - API dispatch: Add arch_major==12 path in k_grouped_fp8_gemm_nt_contiguous, relax recipe assertion to support gran_k=32/128, add SM120 SF layout transform with auto-detection of transposed K-major scale factors. - Tests: Add dedicated SM120 K-grouped test (7 configs including zero-K edge case), fix K-major selection for SM120 in generators, fix test dispatch for SM120 in test_fp8_fp4.py, update FP4 test with perf comparison. Tested: Dense FP8 8/8, Dense FP4 10/10, K-grouped FP8 7/7 — all PASS. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Einsum support: - Add GemmType::Batched to FP8/FP4 and BF16 kernels with 3D TMA load/store - Add IndexType::SF_K for batched SF coordinate computation - Add MN-major B support to BF16 kernel (scalar SMEM loads, single-atom constraint) - BF16 bhr,hdr->bhd: 384 TFLOPS, FP8: 681 TFLOPS (batch=8, b=8192) M-grouped BF16: - Add contiguous and masked M-grouped BF16 GEMM launchers HC prenorm TF32: - New fused GEMM + sqr_sum kernel using mma.sync.m16n8k8 TF32 (226T peak) - BF16 A -> FP32 cast with fused sqr_sum accumulation - Atom-aware FP32 B fragment loading from K-major SMEM - Split-K support for large K / small M shapes - 24/24 test shapes PASS, ~1.1 TB/s bandwidth on large shapes Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Dense (ragged): 651 TFLOPS peak (80% of FP8 MMA peak 814T), 40/40 tests pass. Paged (KV cache): 320 TFLOPS peak, 1.36 TB/s DRAM (91% of HBM BW), 8/8 tests pass. Kernel design: warp-specialized mma.sync m16n8k32 FP8 (no block_scale). 8 math warps × 16 KV rows each = 128 BLOCK_KV. In-warp 2-shfl reduction across 4 threads (lane%4) — only ~10 cycles, negligible vs MMA time. Global stores are fire-and-forget on SM120a, so no epilogue warps needed. Key parameters: block_qh=128, num_heads=64, head_dim=128, 2 Q stages, 3 KV stages, 84KB SMEM (83% of 101KB capacity). Paged variant: 2 groups of 4 warps, SPLIT_KV=128, per-group KV pipeline. Fixed metadata split_kv mismatch and register budget overflow (TMA regs 64→40 to stay within 65536 register limit). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- skip_head_mid: Add SM120 dispatch in attention.hpp for EpilogueHeadSplits. Fix three issues: TMA CD descriptor uses d.size(-1)/d.stride(-2) instead of n; kernel uses stride_d parameter for D row stride and bounds checks; TMA store coordinates apply epilogue N-index remapping. - MN-major B: Fix kernel TMA coordinate for M-grouped BF16 with MN-major B. Group offset moves to outer=K coordinate (not inner=N) when kBKMajor=false. - FP8 kernel stride_d: Add stride_d parameter to decouple D tensor stride from computation dimension n, enabling epilogue transforms that expand N.
Replace per-element scalar SMEM loads in the MN-major B path with ldmatrix.sync.aligned.x2.m8n8.trans.shared.b16 which natively loads column-major 8x8 BF16 matrices — directly producing MMA B fragments. Performance: MN-major B improves from ~220T to ~290T (dense) and ~330T (M-grouped), a 30-50% gain. Remaining gap vs K-major (400T) is due to heuristic selecting BLOCK_N=32/64 vs 128 (single swizzle atom constraint). Verified by micro benchmark b_bf16_4: fragment layout 32x32 lanes PASS, MMA pipeline 4 K-steps accumulation PASS.
Remove single-atom BLOCK_N constraint for MN-major B. ldmatrix.trans correctly handles multi-atom SMEM (verified by micro benchmark b_bf16_5). MN-major B now achieves 99-102% of K-major performance (was 53% with scalar loads, 80% with single-atom ldmatrix.trans).
…bhr->hdr New BF16 bmk,bnk->mn reduction kernel with split-S atomicAdd to FP32 output (188T peak, HBM BW limited). FP8 einsum dispatch for bhd,hdr->bhr and bhd,bhr->hdr via .contiguous() to K-major. Fix batched epilogue stride formula: replace single stride_d with stride_cd_m + stride_cd_batch to support arbitrary D layouts ([batch,M,N] vs [M,batch,N]). Add kBKMajor template parameter to FP8 kernel with verified scalar-load MN-major B path (correct but 3x slower than K-major ldmatrix, kept for future optimization).
Open
K-grouped TN: single .t().contiguous() transpose with constant-stride TMA (kKGroupedConstantStride) — per-group only replaces addr+dim, not stride. TN achieves 99-101% of NT performance. New PTX tensor_map_replace_global_dim_in_smem. Paged MQA varlen: fix 4 kernel bugs in sm120_fp8_paged_mqa_logits.cuh: - TMA Q coordinate: use atom_to_token_idx() instead of hardcoded *kNextNAtom - Prefetch advance: use get_atom_advance() instead of hardcoded +1 - Math loop: conditional iteration count via is_paired_atom for unpaired atoms - KV block idx: reset kv_block_idx_ptr=32 on q_atom change
New kernels using mma.sync m16n8k64 block-scaled FP4 (mxf4nvf4, scale_vec::2X). Architecture: 8 math warps + 4 TMA warps, B64 swizzle, kKSteps=2. Block-scaled MMA folds UE8M0 SF into computation — no post-MMA scale. Dense FP4 MQA: 1022 TFLOPS peak (63% FP4 peak), 1.6x vs FP8. Paged FP4 MQA: 707 TFLOPS peak (varlen), 566 TFLOPS (non-varlen next_n=2).
…dmatrix Uses mxf8f6f4 block_scale e4m3.e2m1 MMA (K=32, scale_vec::1X). B path: TMA 16U4_ALIGN16B → padded SMEM → ldmatrix.m8n16.x2.b8x16.b4x16_p64 (hardware unpack 4-bit→8-bit) → <<2 shift → mixed MMA. Zero host-side overhead. SMEM budget identical to pure FP8 (padded row = BLOCK_K bytes). Peak 670 TFLOPS on 4096×24576×1536.
…GEMM Extend b_is_fp4 detection and TMA .b4x16_p64 descriptor to all 4 grouped launchers (K-grouped, M-grouped contiguous, M-grouped masked, batched). Also fix K-grouped sum_sf_k bug: split into sum_sf_k_a/sum_sf_k_b for correct SF TMA dimensions when gran_k_a != gran_k_b.
…uards to MQA kernels BF16 kernel TMA store path was computing N coordinate directly without epilogue_type_t::apply_index_n, unlike the FP8/FP4 kernel which correctly applies the transform. Also adds #pragma clang diagnostic and #if __CUDA_ARCH__ >= 1200 guards to all 4 SM120 MQA logits kernels, matching the pattern used by all other SM120 kernel files.
Sub-tile epilogue: stream 64-row sub-tiles through compact SMEM_D buffer instead of full 128-row tile. SMEM_D reduced from 32KB to 16KB for BN=128, recovering 1 pipeline stage (3→4 for FP4 BN=128). Heuristic auto-selects sub-tile when it gains stages. TMA CD descriptor box outer dim matches kEpiSubM. tma_store_wait<0> per sub-tile ensures SMEM coherence. B fragment loading: use ldmatrix.x4 for B (2 N-tiles per load) instead of ldmatrix.x2 (1 N-tile per load), halving B load instructions per K-step from 16 to 8. Reduces L1/TEX pipeline pressure from 72.6% to 61.4%, freeing bandwidth for TMA loads. MMA wrappers accept b0/b1 as separate uint32_t args to allow direct x4 output indexing without intermediate array copy. FP4 4096×7168×2048: 1024T → 1091T (+6.5%) FP4 4096×2112×7168: 1101T → 1211T (+10.0%) FP8 4096×24576×1536: 619T → 637T (+2.9%) All 411+ tests pass.
Replace cross-N-tile ldmatrix.x4 (2 N-tiles per load, requires MOV
rearrangement) with per-N-tile ldmatrix.x4 (1 N-tile covering 2 K-steps).
Output {r0,r1}=ks0 and {r2,r3}=ks1 are consecutive register pairs consumed
directly by MMA — verified zero MOV in SASS inner loop (27 total kernel
MOV vs 87 before, all in prologue/epilogue).
Pipeline restructured: B loaded upfront for entire K-step pair, A
double-buffered across K-steps, SF properly hoisted when gran_k >= BLOCK_K.
Fallback path retained for MN-major B and mixed FP8xFP4.
Also adds tests/sm120/run_regression.sh for automated test suite.
Split 8 math warps from (1M×16N) to (2M×8N) cooperative layout: each warp handles 2 M-tiles × 8 N-tiles instead of 1 × 16. Per K-block: B loads 16→8 LDSM, A loads 2→4 LDSM, net 18→12 LDSM. L1/TEX throughput drops from 61% to 41% — SMEM load no longer bottleneck. Tensor pipe rises to 72% of peak (from 71%). FP4 4096×7168×2048: 1094T → 1144T (+4.6%) FP4 4096×2112×7168: 1138T → 1278T (+12.3%) FP4 4096×24576×1536: 1132T → 1168T (+3.2%, exceeds CUTLASS 1146T) FP8 also improves: 597→634T, 587→622T, 637→665T
Update MMA efficiency model to better account for epilogue overhead and A fragment reuse with cooperative warp layout. Larger BN=128 has half the tile count (fewer epilogues) and 2x compute per K-block (better TMA latency amortization), outweighing the stage count reduction (4 vs 7). FP4 4096×7168×2048: 1144T → 1201T (+5.0%) FP4 4096×2112×7168: 1278T → 1307T (+2.3%)
This was referenced May 7, 2026
K-grouped FP4 was an SM120-exclusive feature not needed by DeepSeek V4 (which uses M-grouped FP4 for routed MoE). SM100 never supported it. Add assert to block FP4 inputs in k_grouped_fp8_gemm_nt_contiguous. Remove FP4 test configs from test_k_grouped_fp8.py.
SM120 mma.sync cannot natively handle MN-major operands (ldmatrix.trans only works at b16 granularity, incompatible with 1-byte FP8). Convert MN-major to K-major at the API dispatch layer: - FP8: .contiguous() transpose copy - FP4: fp4_repack_to_k_major() — unpack nibbles, rearrange, repack (simple .contiguous() corrupts FP4 nibble pairing) Also fixes SF transpose kernel SMEM overflow on SM120 (99KB limit) for large SF_K by dynamically computing block_mn from device SMEM capacity.
|
@leavelet this is nice, after it is merged into the nv-dev branch, should vllm-project/vllm#41834 merge too in order for vllm to be able to work with the branch. |
Contributor
|
I add my benchmark at vllm-project/vllm#41834 as references |
jasl
added a commit
to jasl/tokenspeed
that referenced
this pull request
May 12, 2026
Replace the hand-written CUDA FP8 GEMV kernel (previously gated to tokens==1) with a port of the SM120 FP8 einsum kernel from upstream DeepGEMM's WIP SM120 support (deepseek-ai/DeepGEMM#324, file `deep_gemm/include/deep_gemm/impls/sm120_fp8_einsum.cuh`). The DeepGEMM kernel implements exactly the `bhr,hdr->bhd` einsum DeepSeek V4 needs, with per-thread-per-output-cell GEMV using fp8x4 vectorized loads and the same block-128 fp32 scale recipe. Removing the tokens==1 gate: the kernel handles all token counts that the SM12x dispatch predicate accepts (tokens <= 16 today; larger token batches will arrive once T1-α expands graph capture). Microbench (DSv4-Flash decode shape, groups=8, hidden=2048, out=1024, GPU idle): tokens=1 cuda 0.026ms triton 0.020ms speedup 0.72x (was 0.40x) tokens=2 cuda 0.027ms triton 0.020ms speedup 0.72x (was 0.73x*) tokens=8 cuda 0.075ms triton 0.020ms speedup 0.27x (was 0.21x*) * Triton-as-default after the previous tokens==1 hotfix. The kernel's grid is `tokens * groups * (out/128)`, one block per `(token, group, out_tile=128)` triple. Because each block reads its weight tile independently, total weight reads scale linearly with `num_tokens`. At graph bs=2 (today) this dominates: tokens<=2 is the production shape and the 0.72x is a real net win against the previous Triton-fallback default. At tokens=8 (future, post-T1-α) the kernel loses ~2x to Triton's m=16 cooperative tile; we will revisit with a multi-token tile design before T1-α exposes that shape to production. Earlier hand-written attempts (one-cell-per-block, per-thread B=16 accumulator tile, 1-warp m16n8 MMA, 4-warp m16n32 cooperative MMA, 4-warp m16n128 MMA) are documented in `docs/notes/2026-05-09-ds4-sm12x-rejected-experiments.md`. The MMA designs hit either occupancy collapse (80 regs/thread) or insufficient parallelism (64 blocks at decode shape vs Blackwell's 140 SMs), capping out at ~0.51x. The DeepGEMM design wins at the production shape by avoiding tensor cores entirely -- a per-thread GEMV with fp8x4 vectorization and L1/L2-friendly weight access fits the small-M decode profile better than the m=16 MMA tile. Attribution: kernel source ported under MIT license from upstream DeepGEMM (Copyright (c) 2025 DeepSeek). Tokenspeed adaptations are the tvm-ffi binding, stride/scale validation, and the SM12x dispatch integration; the dot-product math is unchanged. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/tokenspeed
that referenced
this pull request
May 12, 2026
Upstream PR lightseekorg#93 added a pre-flight DeepGEMM ``fp8_gemm_nt`` call to ``DeepseekV4Attention._compute_qr_kv``: on success it replaces the reference FP8 linear path, on failure it logs a WARNING per layer and falls back. DeepGEMM does not support SM120/SM121 yet (see PR ``deepseek-ai/DeepGEMM#324`` + ``reference_deepgemm_sm120`` memory), so on the RTX Pro 6000 workstation every layer fires: DeepSeek V4 DeepGEMM FP8 linear failed; falling back to reference FP8 linear. reason=RuntimeError: Assertion error (csrc/apis/layout.hpp:59): Unknown SF transformation The existing per-layer ``_deepseek_v4_deep_gemm_linear_disabled`` flag already catches this for steady-state replay, but it costs one failed call + one WARNING per layer at boot. Mirror the pattern used by ``_deepseek_v4_deepgemm_fp4_indexer_enabled_for_platform``: short- circuit ``_deepseek_v4_get_fp8_linear_deep_gemm`` to ``None`` on SM12x so the platform never tries the DeepGEMM path. Non-SM12x platforms keep the new fast path. Signed-off-by: jasl <jasl9187@hotmail.com>
FP8 paged kernel hardcodes 2-group layout (BLOCK_KV=64), so block_kv=32 causes SMEM overflow. FP4 paged kernel correctly computes num_groups from block_kv. Add assert in FP8 launcher, enable block_kv=32 only for FP4 in test. Remove redundant test_mqa_logits.py wrapper.
…text hoist Restructure the FP8/FP4 mainloop to reduce non-QMMA instruction overhead: - SF-major loop: load SF packed int32 into registers once per kSFTileKBlocks K-blocks, extract bytes with compile-time index via cute::for_each(make_int_sequence). Eliminates redundant SF SMEM loads (10 LDS/K-block → 2.5 amortized). - Compile-time SF byte extraction: cute::for_each guarantees kb_inner is a compile-time type, enabling the compiler to specialize byte 0 (LOP3 only) and byte 3 (SHF only). - SwizzleContext hoist: move a_ctx/b_ctx initialization outside the K-block loop (loop-invariant). - Generalized sf_byte formula: (kb_inner * BLOCK_K / kGranK) % 4 supports both BLOCK_K=128 and BLOCK_K=64. - if constexpr(kUseSFMajorLoop) separates the optimized path (gran_k >= BLOCK_K) from the original flat loop (gran_k < BLOCK_K), eliminating ~300 lines of tail code duplication. SASS improvement (BLOCK_K=128): total mainloop instructions 587→469, QMMA ratio 43.6%→54.6%, IMAD 13→0, LOP3 71→34.
Add BLOCK_K=64 as a heuristic candidate for M >= 2048 non-mixed dtype GEMMs. With BLOCK_K=64, per-stage SMEM halves (17KB vs 34KB), enabling 4 pipeline stages (vs 2 with BLOCK_K=128). The deeper pipeline provides 3 stages of TMA latency slack (vs 1), improving sustained MMA pipe utilization. Benchmark (M=4096, FP8 dense): - 4096×2112×7168: 591T → 624T (+5.7%) - 4096×5120×8192: 669T → 699T (+4.6%) - 4096×5120×25600: 689T → 731T (+6.0%) Constraints: - M < 2048: BK=128 is faster (pipeline fill/drain overhead dominates) - Mixed FP8×FP4: BK=128 only (FP4 .b4x16_p64 TMA incompatible) - B64 swizzle: consequence of BK=64, verified correct on SM120a
The grouped contiguous MoE path on SM120 previously hardcoded BLOCK_M=128 in `get_layout_candidates`, and `get_theoretical_mk_alignment_for_contiguous_layout` matched it with a 128-only floor. For batch=1 decode with 128 experts and topk=6, this forces M_sum = 6 + 128*127 = 16384 padded rows of which only 6 are valid — kernel time scales with M_sum, so ~99% of work is on padding blocks the kernel still has to schedule through. Enable BLOCK_M=64 as a candidate (alongside 128 plus runtime_align if it sits on a 16-aligned value in (64,128]); the existing cycle-based scorer in `compare()` picks the smaller tile when expected_m is small. The kernel-side cooperative-warp layout (kMWarps=4, MMA_M=16, kMTilesPerWarp = BLOCK_M/64) requires BLOCK_M >= 64, so 64 is the practical floor for now; unlocking BLOCK_M=32 needs the kMWarps refactor. Microbenched on RTX PRO 6000 Blackwell at DSv4-Flash MoE shapes (N=4096/2048, K=4096/2048, num_groups=128). FC1 grouped-GEMM time medians at batch=1 decode dropped from 57 us → 37 us (-35%) and the end-to-end vLLM TPOT at TP=2 / 8K-1K dropped from 15.76 ms → 12.48 ms, beating the Marlin MoE path baseline (13.96 ms) by ~11%.
Lifts the BLOCK_M floor from 64 to 32 by switching the kernel's
cooperative warp layout from the hardcoded `kNWarps = 2` (4 M-warps ×
2 N-warps) to a BLOCK_M-dependent value:
BLOCK_M < 64 → kNWarps = 4, kMWarps = 2 (smaller A tile)
BLOCK_M >= 64 → kNWarps = 2, kMWarps = 4 (default, more A reuse)
This keeps kMTilesPerWarp = BLOCK_M / kMWarps / MMA_M = 1 in both cases
and adds a static assert. For the kNWarps=4 path we need BLOCK_N % 32 == 0
so kNTiles = BLOCK_N/8 is divisible by kNWarps; sm120.hpp filters
(block_m, block_n) accordingly.
The runtime alignment helper's SM120 min_block_m drops from 64 to 32
so `get_theoretical_mk_alignment_for_contiguous_layout(expected_m)`
returns 32 for `expected_m <= 32` (e.g. MoE decode batch=1, topk=6:
`expected_m = M*topk = 6` → align=32 → M_sum = 6 + 6*31 = 192, vs 384
at BLOCK_M=64, or the original 16384).
To avoid the cycle-based `compare()` scorer over-selecting BLOCK_M=32
for prefill-sized expected_m (where the launch/pipeline overhead of
many tiny tiles isn't in the model), small-BLOCK_M candidates are
filtered to `expected_m <= 4 * BLOCK_M`. That keeps batch=1..8 on
BLOCK_M=32 and batch≥16 on BLOCK_M=128.
Microbench on RTX PRO 6000 Blackwell, DSv4-Flash MoE shapes (N=4096,
K=4096, num_groups=128), FC1 grouped-GEMM time medians:
batch M_sum BLOCK_M=64 BLOCK_M=32+gate Marlin baseline
1 192 37.4 us 29.0 us 31.2 us
2 384 55.9 37.3 54.9
4 768 168.5 56.9 135.6
8 2304 337.6 151.2 247.4
32 16512 949.8 943.2 589.5
1024 22400 1251.9 1210.9 833.1
End-to-end vLLM bench (DSv4-Flash, TP=2, 8K/1K decode), median TPOT:
Marlin baseline: 13.96 ms
DEEPGEMM_MXFP4
+ BLOCK_M=64: 12.48 ms (-11% vs Marlin)
+ BLOCK_M=32 gated: (pending — verify e2e in vLLM)
…p layout" This reverts commit 8fb3f60.
Mirrors TRT-LLM blockscale-gemm runGemmSwapAB
(cpp/include/tensorrt_llm/deep_gemm/fp8_gemm.cuh). The SM120 1d1d kernel
has BLOCK_M >= 64, so decode-shape BMM with M_orig <= 32 wastes most M
lanes. The caller in apis/einsum.hpp::fp8_bmm now detects this case,
swaps A<->B (and sfa<->sfb before the SF layout transform), and calls
sm120_fp8_fp4_bmm with swap_ab=true. The heuristic relaxes the BLOCK_N
floor only when desc.n <= 32 (post-swap signature), so the normal path
keeps BN in {64, 128}.
The kernel takes a new runtime stride_cd_n. In normal mode it is 1 (cols
adjacent, pair-write fast path). In swap mode it is d.stride(-2) = N_orig,
so the kernel's (m_kernel=n_orig, n_kernel=m_orig) writes land byte-for-
byte in the caller's (B, M_orig, N_orig) row-major buffer with no temp
allocation. The direct-store epilogue picks pair vs single-element based
on stride_cd_n == 1. swizzle_cd_mode is forced to 0 below BN*cd_size=128
so kUseTMAStoreEpilogue is off in swap mode (the natural TMA-CD layout
can't write the swapped indices correctly).
Correctness: bit-identical output to jasl Triton einsum across M=1..32
(torch.equal == True, max abs diff 0).
Perf: ~25% faster than jasl at kernel level (ncu) across M=1..16.
Normal path (M >= 33) unchanged.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Wires the same swap-AB dispatch into apis/gemm.hpp::fp8_fp4_gemm_nt so decode-shape MLP / projection GEMMs (M_orig <= 32) also bypass the BLOCK_M = 64 floor. sm120_fp8_fp4_gemm_1d1d gains the swap_ab flag and runtime stride_cd_n the bmm variant already used, so the kernel writes directly into the caller's (M_orig, N_orig) row-major buffer with no temp allocation. Recipes can be asymmetric on plain GEMM (per-token A + per-block B defaults to (1, 128, 128) on SM12). When we swap operands we must also swap the recipe's gran_mn entries; otherwise transform_sf_into_required trips on sf.size(-2) != ceil_div(mn, gran_mn). Both gemm.hpp and einsum.hpp now build a swapped recipe and use the recipe_a/recipe_b form of transform_sf_pair_into_required_layout. einsum.hpp was correct for DSv4's symmetric (1, 1, 128) recipe before this change but would have silently produced shape errors on any asymmetric caller. Verified: 11/11 plain-GEMM correctness shapes at M=1..32 (swap) and M=64,128 (normal) within FP8 tolerance; all upstream SM120 regression tests pass (8/8 dense FP8, 7/7 k-grouped, 34/34 m-grouped FP8/FP4); einsum BMM swap path remains bit-identical to jasl Triton at M=1..32. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 18, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 19, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 20, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
|
nice work! may I ask the hardware for testing here is either 5090 or RTX6000pro? |
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
jasl
added a commit
to jasl/vllm
that referenced
this pull request
May 22, 2026
Widen the direct FP8 MQA logits Triton fallback from BLOCK_M=8 to BLOCK_M=16 while keeping BLOCK_N=128 and the existing 4-warp launch. This reduces CTA count for late-context prefill without introducing a runtime switch. The direction was motivated by the tile-shape discussion in deepseek-ai/DeepGEMM#324, but this is a vLLM-owned Triton fallback adjustment and does not copy DeepGEMM code. On the SM120 long-context gate with prefix cache disabled, the 128K synthetic mean TTFT improved from 36.541s to 33.264s at C=1, 56.902s to 49.199s at C=2, and 96.317s to 82.181s at C=4. GSM8K exact_match_flexible stayed at 0.95. Signed-off-by: jasl <jasl9187@hotmail.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Added sm120 support for DeepGEMM, with performance on par with cuBLAS and outperforming the CUTLASS version in most cases.
Important: This repository is still a work in progress and has not been cleaned up yet. It is being released early to prevent duplicate effort in the community.
sm120 support will be maintained in the
nv_devbranch by NVIDIA DevTech APAC.Performance metrics:
Dense GEMM
Grouped GEMM (FP8 / FP4)
Einsum (Batched GEMM)
bhr,hdr->bhd)bhd,hdr->bhr)bmk,bnk->mn)ldmatrix.trans.x2+ multi-atomBLOCK_N=128optimization.HC Prenorm (Fused GEMM + sqr_sum, TF32)
MQA Logits (Attention, FP8)
Co-authored by: @lucifer1004