Skip to content

feat: MegaMOE adaptation for SM90#323

Open
qiushixiaoyu wants to merge 5 commits into
deepseek-ai:mainfrom
qiushixiaoyu:main
Open

feat: MegaMOE adaptation for SM90#323
qiushixiaoyu wants to merge 5 commits into
deepseek-ai:mainfrom
qiushixiaoyu:main

Conversation

@qiushixiaoyu
Copy link
Copy Markdown

@qiushixiaoyu qiushixiaoyu commented Apr 30, 2026

Add mega moe support for sm90.
Use the following command to test:

python tests/test_mega_moe_sm90.py --layers 1 2 3 --num-processes 8 --fail-fast
python tests/test_mega_moe_sm90.py --layers 4 --num-processes 8 --fail-fast
python tests/test_mega_moe_sm90.py --layers 5 --num-correctness-tests 16 --num-processes 8

Co-authored with AI

@qinqinwo
Copy link
Copy Markdown

qinqinwo commented May 7, 2026

Do you have benchmark data?

@qiushixiaoyu
Copy link
Copy Markdown
Author

qiushixiaoyu commented May 11, 2026

Do you have benchmark data?

I’m testing the benefits of DeepSeek V4 Flash on H20, and I’ll share the data soon.

@Stone749990226
Copy link
Copy Markdown

看起来效果并不理想:
CUDA_VISIBLE_DEVICES=0,2,3,4 python tests/test_mega_moe_hopper.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 3072 --num-experts 384 --num-topk 6 --num-bench-tests 30
Config (H200 fused mega-MoE):

Tokens: 8192/8192
Hidden: 7168, Intermediate: 3072
Experts: 6/384 (per-rank: 96)
Activation SF: fused L2 per-64 UE8M0, baseline L2 per-128 UE8M0 (SM90 grouped GEMM constraint)
Buffer: 4.268 GiB

Performance:

[fused] EP 3/4 | 245 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 0/4 | 243 TFLOPS | overlap: 244 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 2/4 | 244 TFLOPS | overlap: 245 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26619 us, reduction: 126.5 us
[fused] EP 1/4 | 244 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26618 us, reduction: 126.5 us
[baseline] EP 2/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9638 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 3/4 | 676 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9640 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 1/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9642 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 0/4 | 669 TFLOPS | HBM 800 GB/s, NVL 110 GB/s | 9653 us | t_baseline/t_fused = 0.36x (baseline 更快)

"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。

结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
  * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
           使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
  * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
              使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
              per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
              同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
              不是 bitwise apples-to-apples correctness oracle。
  * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
                  reduction us / `t_baseline / t_fused` legacy 比。
"""

import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple

import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto


# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"


# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128


# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入  x        : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入  topk_w   : (M,)     fp32,可选
# 输出  y        : (M, H)   fp8_e4m3fn
# 输出  y_sf     : (M, H/BLOCK_K) fp32 行主序
# ============================================================================


@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
    x_ptr,
    topk_w_ptr,
    y_ptr,
    y_sf_ptr,
    M,
    H,  # 运行时形状
    stride_xm,
    stride_xn,  # x: (M, 2H) 的 stride
    stride_ym,
    stride_yn,  # y: (M, H)  的 stride
    stride_sfm,
    stride_sfk,  # y_sf: (M, H/BLOCK_K) 的 stride
    clamp_value,  # 当 HAS_CLAMP=False 时这个参数无意义
    HAS_TOPK: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
    USE_UE8M0_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,  # = num_per_channels
):
    # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 当前 K-block 内的列索引(在 H 维度,不是 2H)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
    # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
    gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
    up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
    gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
    if HAS_CLAMP:
        gate = tl.minimum(gate, clamp_value)
        up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)

    # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
    y = gate * tl.sigmoid(gate) * up

    # ---- 4) 可选 MoE 权重缩放(per-token 标量)----
    if HAS_TOPK:
        w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
        y = y * w[:, None]

    # ---- 5) 当前 K-block 内每行 absmax → scale ----
    amax = tl.max(tl.abs(y), axis=1)  # (BLOCK_M,)
    sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
    if USE_UE8M0_SCALE:
        # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
        # scale = 2 ** ceil(log2(amax / 448)).
        sf = tl.exp2(tl.ceil(tl.log2(sf)))

    # ---- 6) 量化为 FP8 e4m3fn ----
    y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)

    # ---- 7) 写回 y 和 sf ----
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
    tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])

    sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
    tl.store(sf_ptrs, sf, mask=mask_m)


def swiglu_apply_weight_to_fp8_triton(
    x: torch.Tensor,
    topk_weights: torch.Tensor | None,
    clamp_value: float | None = None,
    num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
    use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU + FP8 量化。语义等价于 PyTorch reference:
    gate, up = x[:, :H], x[:, H:]
    y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
    y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
    if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
    y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
    """
    assert x.is_cuda and x.dtype == torch.bfloat16
    assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
    M, two_H = x.shape
    H = two_H // 2
    assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"

    y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
    y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)

    # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
    BLOCK_M = 16
    grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)

    # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
    topk_ptr = topk_weights if topk_weights is not None else x

    _swiglu_apply_weight_to_fp8_kernel[grid](
        x,
        topk_ptr,
        y,
        y_sf,
        M,
        H,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_sf.stride(0),
        y_sf.stride(1),
        float(clamp_value) if clamp_value is not None else 0.0,
        HAS_TOPK=topk_weights is not None,
        HAS_CLAMP=clamp_value is not None,
        USE_UE8M0_SCALE=use_ue8m0_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_K=num_per_channels,
    )
    return y, y_sf


# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
#   每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
#   * 不需要 deep_gemm.transform_sf_into_required_layout
#   * SF 是 FP32,不是 UE8M0 packed
# ============================================================================


def _quantize_grouped_fp8_block_128_128(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
    g, n, k = w.shape
    assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"

    # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
    w_view = w.view(g, n // 128, 128, k // 128, 128).float()

    # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
    amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4)  # (G, N/128, K/128)
    sf = amax / FP8_E4M3_MAX

    # 量化:每个元素除以所属子块的 sf 后转 FP8
    # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
    w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
    return w_fp8.view(g, n, k).contiguous(), sf.contiguous()


# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================


def _import_deep_ep():
    try:
        import deep_ep

        return deep_ep
    except Exception as ex:
        dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
        return None


# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================


def _bench_cuda_events(
    fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
    """返回 fn 的中位数耗时(秒)。"""
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(num_repeat):
        # L2 flush,避免重复访问命中 cache 让测时偏低
        if l2_flush_gb > 0:
            free_bytes, _ = torch.cuda.mem_get_info()
            flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
            if flush_bytes >= 4:
                torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        fn()
        e.record()
        e.synchronize()
        times_ms.append(s.elapsed_time(e))
    times_ms.sort()
    return times_ms[len(times_ms) // 2] / 1e3


# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================


def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # 形状参数(与 test_mega_moe.py 同名同义)
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank
    assert num_experts % num_ranks == 0, (
        f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
    )

    # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
    #   * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
    #   * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
    assert hidden % 128 == 0
    assert intermediate_hidden % 128 == 0
    assert intermediate_hidden // 64 <= 64, (
        f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
    )

    # ---- 创建 BF16 输入:token 与两层 weight ----
    # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
    x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
    l1_weights_bf16 = torch.randn(
        (num_experts_per_rank, intermediate_hidden * 2, hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # L2 weight: 每个 expert 把 intermediate_hidden → hidden
    l2_weights_bf16 = torch.randn(
        (num_experts_per_rank, hidden, intermediate_hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )

    # 路由:scores → topk_idx (M, K) + topk_weights (M, K)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
    topk_weights, topk_idx = torch.topk(
        scores, num_topk, dim=-1, largest=True, sorted=False
    )

    # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
    cum_stats_fused = torch.zeros(
        (num_experts_per_rank,), dtype=torch.int, device="cuda"
    )
    cum_stats_baseline = cum_stats_fused.clone()

    # ---- BF16 → FP8 量化 ----
    # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
    # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
    x_fp8 = per_token_cast_to_fp8(
        x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
    )

    # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
    # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
    l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
    l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)

    # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
    transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
        l1_weights, l2_weights
    )

    # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
    clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None

    # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)

    # ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
    sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group,
        num_experts,
        num_max_tokens_per_rank,
        num_topk,
        hidden,
        intermediate_hidden,
    )
    y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def run_fused():
        # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
        # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
        sym_buffer.x[:num_tokens].copy_(x_fp8[0])
        sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
        sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
        sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)

        deep_gemm.fp8_mega_moe(
            y_fused,
            transformed_l1,
            transformed_l2,
            sym_buffer,
            cumulative_local_expert_recv_stats=cum_stats_fused,
            recipe=(128, 128, 128),
            activation="swiglu",
            activation_clamp=clamp_arg,
            fast_math=bool(args.fast_math),
        )
        return y_fused

    # ---- 分配 DeepEP buffer(baseline 用)----
    deep_ep = _import_deep_ep()
    ep_buffer = None
    if deep_ep is not None:
        ep_buffer = deep_ep.ElasticBuffer(
            group,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            hidden=hidden,
            num_topk=num_topk,
            use_fp8_dispatch=True,
            explicitly_destroy=True,
            allow_multiple_reduction=False,
        )

    # ----------------------------------------------------------------
    # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
    # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
    # 的版本(baseline grouped GEMM 不需要 gate/up interleave)
    # ----------------------------------------------------------------
    def run_baseline():
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x_fp8,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cum_stats_baseline,
            num_experts=num_experts,
            expert_alignment=alignment,
            do_cpu_sync=False,
            do_handle_copy=False,
            do_expand=True,
            use_tma_aligned_col_major_sf=False,  # SM90: row-major float SF
        )
        n = recv_x[0].size(0)

        # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
        l1_y = torch.empty(
            (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
        )
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            recv_x,
            l1_weights,
            l1_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # Triton SwiGLU + FP8 量化(含 topk 权重乘法)
        # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
        # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
        # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
        l1_y = swiglu_apply_weight_to_fp8_triton(
            x=l1_y,
            topk_weights=recv_topk_weights,
            clamp_value=clamp_arg,
            num_per_channels=BASELINE_L2_ACT_SF_GRAN,
            use_ue8m0_scale=True,
        )

        # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            l1_y,
            l2_weights,
            l2_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
        return ep_buffer.combine(l2_y, handle=handle)[0]

    # ---- 打印 config ----
    dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
    dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
    dist_print(
        f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
    )
    dist_print(
        f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
        once_in_node=True,
    )
    dist_print(
        f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
        f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
        f"(SM90 grouped GEMM constraint)",
        once_in_node=True,
    )
    dist_print(
        f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
    )
    dist_print(once_in_node=True)

    # ---- 跑一次确保不报错(fused + 可选 baseline)----
    y = run_fused()
    assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
        f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
    )
    if ep_buffer is not None:
        out_b = run_baseline()
        assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
            f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
        )
        if args.check_output_diff:
            diff = (y.float() - out_b.float()).abs()
            denom = out_b.float().abs().mean().clamp_min(1e-12)
            dist_print(
                "Output diff (fused vs legacy-per128 baseline):", once_in_node=True
            )
            dist_print(
                f" > max_abs={diff.max().item():.6e}, "
                f"mean_abs={diff.mean().item():.6e}, "
                f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
                once_in_node=True,
            )
            dist_print(once_in_node=True)

    # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
    # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
    # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    gathered_topk_idx[
        (gathered_topk_idx < rank_idx * num_experts_per_rank)
        | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
    ] = -1
    num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
    num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)

    # ---- benchmark ----
    # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
    t_fused = bench_kineto(
        run_fused,
        SM90_KERNEL_NAME,
        num_tests=args.num_bench_tests,
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
        if ep_buffer is not None
        else dist.barrier(),
        trace_path=(
            f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
            if args.dump_profile_traces
            else None
        ),
    )
    # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
    t_baseline = (
        _bench_cuda_events(
            run_baseline,
            num_warmup=args.num_warmup,
            num_repeat=args.num_repeat,
            l2_flush_gb=args.l2_flush_gb,
        )
        if ep_buffer is not None
        else 0.0
    )

    def safe_div(a, b):
        return float("nan") if b == 0 else a / b

    # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
    tflops = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
    )

    # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
    l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
    l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
    l1_weight_sf_bytes = (
        num_touched_experts
        * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
        * (hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l2_weight_sf_bytes = (
        num_touched_experts
        * (hidden // WEIGHT_SF_GRAN_MN)
        * (intermediate_hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
    l2_act_sf_bytes = (
        num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
    )
    num_hbm_bytes = (
        l1_weight_bytes
        + l2_weight_bytes  # weights (FP8)
        + l1_weight_sf_bytes
        + l2_weight_sf_bytes  # weight SF (FP32)
        + num_recv_tokens * hidden
        + l1_input_sf_bytes  # L1 输入读 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L1 输出写 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L2 输入读 (FP8 + SF)
        + num_recv_tokens * hidden * 2  # L2 输出写 (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
    num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
    approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)

    # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
    tflops_baseline = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
    )
    hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
    nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)

    dist_print("Performance:", once_in_node=True)
    dist_print(
        f" > [fused]    EP {rank_idx:2}/{num_ranks} | "
        f"{tflops:4.0f} TFLOPS | "
        f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
        f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
        f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
        f"{t_fused * 1e6:6.0f} us, "
        f"reduction: {t_reduction * 1e6:5.1f} us"
    )
    if ep_buffer is not None:
        speedup = safe_div(t_baseline, t_fused)
        dist_print(
            f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
            f"{tflops_baseline:4.0f} TFLOPS | "
            f"               HBM {hbm_gbs_baseline:4.0f} GB/s, "
            f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
            f"{t_baseline * 1e6:6.0f} us | "
            f"t_baseline/t_fused = {speedup:.2f}x "
            f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
        )
    else:
        dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)

    # ---- 清理 ----
    dist.barrier()
    sym_buffer.destroy()
    if ep_buffer is not None:
        ep_buffer.destroy()
    dist.destroy_process_group()


# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
    )

    # 资源
    parser.add_argument(
        "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
    )

    # 模型形状
    # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
    parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=0,
        help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
    )
    parser.add_argument("--hidden", type=int, default=7168)
    parser.add_argument(
        "--intermediate-hidden",
        type=int,
        default=3072,
        help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
    )
    parser.add_argument(
        "--activation-clamp",
        type=float,
        default=10.0,
        help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
    )
    parser.add_argument("--num-experts", type=int, default=384)
    parser.add_argument("--num-topk", type=int, default=6)
    parser.add_argument(
        "--fast-math",
        type=int,
        default=1,
        help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
    )

    # 测时
    parser.add_argument(
        "--num-bench-tests",
        type=int,
        default=30,
        help="bench_kineto 抓 fused 时的迭代数",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="baseline cuda events warmup"
    )
    parser.add_argument(
        "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
    )
    parser.add_argument(
        "--l2-flush-gb",
        type=float,
        default=8.0,
        help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
    )
    parser.add_argument(
        "--check-output-diff",
        type=int,
        default=0,
        help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
    )
    parser.add_argument(
        "--dump-profile-traces",
        type=str,
        default="",
        help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
    )

    args = parser.parse_args()

    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
    torch.multiprocessing.spawn(
        test, args=(args.num_processes, args), nprocs=args.num_processes
    )

@MikeFang-dev
Copy link
Copy Markdown

看起来效果并不理想: CUDA_VISIBLE_DEVICES=0,2,3,4 python tests/test_mega_moe_hopper.py --num-processes 4 --num-max-tokens-per-rank 8192 --num-tokens 8192 --hidden 7168 --intermediate-hidden 3072 --num-experts 384 --num-topk 6 --num-bench-tests 30 Config (H200 fused mega-MoE):

Tokens: 8192/8192
Hidden: 7168, Intermediate: 3072
Experts: 6/384 (per-rank: 96)
Activation SF: fused L2 per-64 UE8M0, baseline L2 per-128 UE8M0 (SM90 grouped GEMM constraint)
Buffer: 4.268 GiB

Performance:

[fused] EP 3/4 | 245 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 0/4 | 243 TFLOPS | overlap: 244 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26617 us, reduction: 126.5 us
[fused] EP 2/4 | 244 TFLOPS | overlap: 245 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26619 us, reduction: 126.5 us
[fused] EP 1/4 | 244 TFLOPS | overlap: 246 TFLOPS, HBM 292 GB/s, NVL 40 GB/s | 26618 us, reduction: 126.5 us
[baseline] EP 2/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9638 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 3/4 | 676 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9640 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 1/4 | 675 TFLOPS | HBM 802 GB/s, NVL 111 GB/s | 9642 us | t_baseline/t_fused = 0.36x (baseline 更快)
[baseline] EP 0/4 | 669 TFLOPS | HBM 800 GB/s, NVL 110 GB/s | 9653 us | t_baseline/t_fused = 0.36x (baseline 更快)

"""
H200 (SM90 / Hopper) mega-MoE: fused kernel + 同管线 baseline 性能对比。

结构对齐 tests/test_mega_moe.py(B 系列 SM100 FP4 路径),但所有路径都换成 H200 FP8:
  * fused:调用 `deep_gemm.fp8_mega_moe`(kernel symbol `sm90_fp8_mega_moe_impl`),
           使用 `transform_weights_for_mega_moe_sm90` 处理过的权重 + SymmBuffer。
  * baseline:DeepEP dispatch + 2 个 grouped FP8 GEMM + Triton SwiGLU + DeepEP combine,
              使用未变换的权重。由于当前 SM90 grouped GEMM 只支持 L2 activation
              per-128-K SFA,而 fused SM90 mega-MoE 的 L1 epilogue 为避免跨 CTA
              同步使用 per-64-K SFA,所以该 baseline 是同管线 legacy 参照,
              不是 bitwise apples-to-apples correctness oracle。
  * 性能输出涵盖:TFLOPS / overlap TFLOPS / HBM GB/s / NVL GB/s / fused us /
                  reduction us / `t_baseline / t_fused` legacy 比。
"""

import deep_ep
import argparse
import math
import os
import random
import torch
import torch.distributed as dist
import triton
import triton.language as tl
from typing import Tuple

import deep_gemm
from deep_gemm.utils import per_token_cast_to_fp8
from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather
from deep_gemm.testing import bench_kineto


# 与 deep_gemm/include/deep_gemm/impls/sm90_fp8_mega_moe.cuh 中模板入口同名,
# bench_kineto 用它从 trace 里挑出 fused mega-MoE 的 GPU 段
SM90_KERNEL_NAME = "sm90_fp8_mega_moe_impl"


# FP8 e4m3fn 的最大可表示值,量化时用 amax / 448 作为 scale 基准
FP8_E4M3_MAX = 448.0
# 新版 Triton(>= 3.x)强制:jit 内核读到的 Python 全局必须是 tl.constexpr 实例,
# 否则编译期 NameError。宿主 Python 侧仍用上面的普通 float 做 torch 运算。
_FP8_E4M3_MAX_TL = tl.constexpr(448.0)
L1_ACT_SF_GRAN = 128
FUSED_L2_ACT_SF_GRAN = 64
BASELINE_L2_ACT_SF_GRAN = 128
WEIGHT_SF_GRAN_MN = 128
WEIGHT_SF_GRAN_K = 128


# ============================================================================
# 模块 1:Triton SwiGLU + FP8 量化内核
# ----------------------------------------------------------------------------
# baseline 的 L2 仍走 DeepGEMM SM90 grouped FP8 GEMM,所以 activation SFA 只能按
# per-128-K 输入;但 scale 数值采用 fused epilogue 同款 UE8M0/power-of-two 规则,
# 避免再额外引入 exact-FP32-scale 差异。
# 输入  x        : (M, 2*H) bf16,内层是 [gate_part | up_part]
# 输入  topk_w   : (M,)     fp32,可选
# 输出  y        : (M, H)   fp8_e4m3fn
# 输出  y_sf     : (M, H/BLOCK_K) fp32 行主序
# ============================================================================


@triton.jit
def _swiglu_apply_weight_to_fp8_kernel(
    x_ptr,
    topk_w_ptr,
    y_ptr,
    y_sf_ptr,
    M,
    H,  # 运行时形状
    stride_xm,
    stride_xn,  # x: (M, 2H) 的 stride
    stride_ym,
    stride_yn,  # y: (M, H)  的 stride
    stride_sfm,
    stride_sfk,  # y_sf: (M, H/BLOCK_K) 的 stride
    clamp_value,  # 当 HAS_CLAMP=False 时这个参数无意义
    HAS_TOPK: tl.constexpr,
    HAS_CLAMP: tl.constexpr,
    USE_UE8M0_SCALE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,  # = num_per_channels
):
    # 一个 program 处理 (BLOCK_M 个 token) × (第 pid_k 个 K-block 的 BLOCK_K 列)
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # 行索引:本 program 负责 [pid_m*BLOCK_M, pid_m*BLOCK_M+BLOCK_M)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # 当前 K-block 内的列索引(在 H 维度,不是 2H)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    mask_m = offs_m < M

    # ---- 1) 载入 gate(x 的前半段 [0, H))和 up(x 的后半段 [H, 2H))----
    # 注意 stride_xn 是元素 stride(一般 == 1),但 H + offs_k 偏移是按"元素"算的
    gate_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xn
    up_ptrs = x_ptr + offs_m[:, None] * stride_xm + (H + offs_k[None, :]) * stride_xn
    gate = tl.load(gate_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    up = tl.load(up_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # ---- 2) 可选 clamp(参考 tilelang 实现:gate 单边 max,up 双边)----
    if HAS_CLAMP:
        gate = tl.minimum(gate, clamp_value)
        up = tl.minimum(tl.maximum(up, -clamp_value), clamp_value)

    # ---- 3) SwiGLU:silu(gate) * up = gate * sigmoid(gate) * up(全程 FP32 累计)----
    y = gate * tl.sigmoid(gate) * up

    # ---- 4) 可选 MoE 权重缩放(per-token 标量)----
    if HAS_TOPK:
        w = tl.load(topk_w_ptr + offs_m, mask=mask_m, other=1.0)
        y = y * w[:, None]

    # ---- 5) 当前 K-block 内每行 absmax → scale ----
    amax = tl.max(tl.abs(y), axis=1)  # (BLOCK_M,)
    sf = tl.maximum(amax / _FP8_E4M3_MAX_TL, 1.0e-30)
    if USE_UE8M0_SCALE:
        # 对齐 deep_gemm/common/math.cuh::get_e4m3_sf_and_sf_inv:
        # scale = 2 ** ceil(log2(amax / 448)).
        sf = tl.exp2(tl.ceil(tl.log2(sf)))

    # ---- 6) 量化为 FP8 e4m3fn ----
    y_fp8 = (y / sf[:, None]).to(tl.float8e4nv)

    # ---- 7) 写回 y 和 sf ----
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_k[None, :] * stride_yn
    tl.store(y_ptrs, y_fp8, mask=mask_m[:, None])

    sf_ptrs = y_sf_ptr + offs_m * stride_sfm + pid_k * stride_sfk
    tl.store(sf_ptrs, sf, mask=mask_m)


def swiglu_apply_weight_to_fp8_triton(
    x: torch.Tensor,
    topk_weights: torch.Tensor | None,
    clamp_value: float | None = None,
    num_per_channels: int = BASELINE_L2_ACT_SF_GRAN,
    use_ue8m0_scale: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU + FP8 量化。语义等价于 PyTorch reference:
    gate, up = x[:, :H], x[:, H:]
    y = silu(gate.clamp(max=c)) * up.clamp(-c, c) * topk_w
    y_sf = y.view(M, H/np, np).abs().amax(-1) / 448
    if use_ue8m0_scale: y_sf = ceil_to_power_of_2(y_sf)
    y_fp8 = (y / y_sf.unsqueeze(-1)).to(fp8)
    """
    assert x.is_cuda and x.dtype == torch.bfloat16
    assert x.is_contiguous(), "当前实现假设 x 是 contiguous 的,避免 stride 计算错位"
    M, two_H = x.shape
    H = two_H // 2
    assert H % num_per_channels == 0, f"H={H} 必须是 {num_per_channels} 的整数倍"

    y = torch.empty((M, H), dtype=torch.float8_e4m3fn, device=x.device)
    y_sf = torch.empty((M, H // num_per_channels), dtype=torch.float32, device=x.device)

    # BLOCK_M 取 16:内核每个 program 处理 16 个 token × 128 列,寄存器压力小、容易调
    BLOCK_M = 16
    grid = (triton.cdiv(M, BLOCK_M), H // num_per_channels)

    # HAS_TOPK=False 时仍要传一个有效指针(Triton 不允许 nullptr),用 x 占位
    topk_ptr = topk_weights if topk_weights is not None else x

    _swiglu_apply_weight_to_fp8_kernel[grid](
        x,
        topk_ptr,
        y,
        y_sf,
        M,
        H,
        x.stride(0),
        x.stride(1),
        y.stride(0),
        y.stride(1),
        y_sf.stride(0),
        y_sf.stride(1),
        float(clamp_value) if clamp_value is not None else 0.0,
        HAS_TOPK=topk_weights is not None,
        HAS_CLAMP=clamp_value is not None,
        USE_UE8M0_SCALE=use_ue8m0_scale,
        BLOCK_M=BLOCK_M,
        BLOCK_K=num_per_channels,
    )
    return y, y_sf


# ============================================================================
# 模块 2:grouped weight 的 (128, 128) FP8 块量化
# ----------------------------------------------------------------------------
# m_grouped_fp8_gemm_nt_contiguous 在 SM90 上对 weight 的输入约定:
#   每 (128, 128) 子块共享一个 FP32 SF,K 是 SF 的内层连续维(K-major)。
# 与 SM100 FP4 路径的差异:
#   * 不需要 deep_gemm.transform_sf_into_required_layout
#   * SF 是 FP32,不是 UE8M0 packed
# ============================================================================


def _quantize_grouped_fp8_block_128_128(
    w: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """(G, N, K) bf16 → (G, N, K) fp8_e4m3fn + (G, N//128, K//128) fp32 SF。"""
    g, n, k = w.shape
    assert n % 128 == 0 and k % 128 == 0, f"weight 的 N={n}, K={k} 都必须是 128 的倍数"

    # 把 (N, K) 切成 (N/128, 128, K/128, 128),最后一维和倒数第三维就是 128×128 子块内部
    w_view = w.view(g, n // 128, 128, k // 128, 128).float()

    # 子块内 absmax → scale = amax / 448,clamp(1e-4) 避免全 0 子块
    amax = w_view.abs().amax(dim=(-1, -3)).clamp(1e-4)  # (G, N/128, K/128)
    sf = amax / FP8_E4M3_MAX

    # 量化:每个元素除以所属子块的 sf 后转 FP8
    # sf 形状 (G, N/128, K/128),需在 N-内 (axis -3) 和 K-内 (axis -1) 都补维度
    w_fp8 = (w_view / sf.unsqueeze(-1).unsqueeze(-3)).to(torch.float8_e4m3fn)
    return w_fp8.view(g, n, k).contiguous(), sf.contiguous()


# ============================================================================
# 模块 3:尝试导入 deep_ep(用于 dispatch / combine)
# ============================================================================


def _import_deep_ep():
    try:
        import deep_ep

        return deep_ep
    except Exception as ex:
        dist_print(f"Failed to import deep_ep: {ex}", once_in_node=True)
        return None


# ============================================================================
# 模块 4:CUDA event 中位数测时(避开对 tilelang.do_bench 的依赖)
# ============================================================================


def _bench_cuda_events(
    fn, num_warmup: int = 5, num_repeat: int = 20, l2_flush_gb: float = 8.0
) -> float:
    """返回 fn 的中位数耗时(秒)。"""
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    times_ms = []
    for _ in range(num_repeat):
        # L2 flush,避免重复访问命中 cache 让测时偏低
        if l2_flush_gb > 0:
            free_bytes, _ = torch.cuda.mem_get_info()
            flush_bytes = min(int(l2_flush_gb * 1e9), int(free_bytes * 0.5))
            if flush_bytes >= 4:
                torch.empty(flush_bytes // 4, dtype=torch.int, device="cuda").zero_()
        s = torch.cuda.Event(enable_timing=True)
        e = torch.cuda.Event(enable_timing=True)
        s.record()
        fn()
        e.record()
        e.synchronize()
        times_ms.append(s.elapsed_time(e))
    times_ms.sort()
    return times_ms[len(times_ms) // 2] / 1e3


# ============================================================================
# 模块 5:test() 主入口 — 在每个 rank 上跑一遍 baseline
# ============================================================================


def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
    # 初始化分布式:rank_idx 是全局 rank,group 是默认 NCCL group
    rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks)
    torch.manual_seed(rank_idx)
    random.seed(rank_idx)

    # 形状参数(与 test_mega_moe.py 同名同义)
    num_max_tokens_per_rank = args.num_max_tokens_per_rank
    num_tokens = args.num_tokens if args.num_tokens > 0 else num_max_tokens_per_rank
    hidden, intermediate_hidden = args.hidden, args.intermediate_hidden
    num_experts, num_topk = args.num_experts, args.num_topk
    num_experts_per_rank = num_experts // num_ranks
    assert num_tokens <= num_max_tokens_per_rank
    assert num_experts % num_ranks == 0, (
        f"num_experts={num_experts} 必须能被 num_ranks={num_ranks} 整除"
    )

    # SM90 fused kernel 的形状约束(来自 csrc/apis/mega.hpp::fp8_mega_moe):
    #   * H、IH 必须是 128 的倍数(L1 input per-128-K SF + block-(128,128) weight SF)
    #   * IH/64 ≤ 64 → IH ≤ 4096(l2_arrival_mask 是 uint64,每 bit 对应 64 列)
    assert hidden % 128 == 0
    assert intermediate_hidden % 128 == 0
    assert intermediate_hidden // 64 <= 64, (
        f"SM90 fused kernel 要求 intermediate_hidden <= 4096, 当前 {intermediate_hidden}"
    )

    # ---- 创建 BF16 输入:token 与两层 weight ----
    # x: 每 rank 本地 num_tokens 个 token,每个 token hidden 维
    x_bf16 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")
    # L1 weight: 每个 expert 把 hidden → 2*intermediate_hidden(gate 和 up 拼一起)
    l1_weights_bf16 = torch.randn(
        (num_experts_per_rank, intermediate_hidden * 2, hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # L2 weight: 每个 expert 把 intermediate_hidden → hidden
    l2_weights_bf16 = torch.randn(
        (num_experts_per_rank, hidden, intermediate_hidden),
        dtype=torch.bfloat16,
        device="cuda",
    )

    # 路由:scores → topk_idx (M, K) + topk_weights (M, K)
    scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device="cuda")
    topk_weights, topk_idx = torch.topk(
        scores, num_topk, dim=-1, largest=True, sorted=False
    )

    # 累计接收统计:fused 与 baseline 各持一份避免相互覆盖
    cum_stats_fused = torch.zeros(
        (num_experts_per_rank,), dtype=torch.int, device="cuda"
    )
    cum_stats_baseline = cum_stats_fused.clone()

    # ---- BF16 → FP8 量化 ----
    # x_fp8 是元组:(token_fp8 (M, hidden), token_sf (M, hidden//128) fp32 行主序)
    # 注意 use_ue8m0=False, use_packed_ue8m0=False:SM90 不接受 UE8M0 packed SF
    x_fp8 = per_token_cast_to_fp8(
        x_bf16, use_ue8m0=False, gran_k=128, use_packed_ue8m0=False
    )

    # weight 量化:(G, N, K) bf16 → ((G, N, K) fp8 e4m3fn, (G, N//128, K//128) fp32 SF)
    # baseline(DeepEP grouped GEMM)直接用这两个未变换的元组
    l1_weights = _quantize_grouped_fp8_block_128_128(l1_weights_bf16)
    l2_weights = _quantize_grouped_fp8_block_128_128(l2_weights_bf16)

    # fused 路径:FP8 weight 上做 gate/up gran-8 N-轴 interleave;SF 不变
    transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe_sm90(
        l1_weights, l2_weights
    )

    # SwiGLU clamp:finite → 传给 fused/triton;inf → None(关闭 clamp,与 SM90 fused 一致)
    clamp_arg = args.activation_clamp if math.isfinite(args.activation_clamp) else None

    # ---- DeepGEMM grouped GEMM 的 M 维 alignment(baseline 走 DeepEP 时也用这个)----
    alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
    deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)

    # ---- 分配 fused 的 SymmBuffer 与输出 buffer ----
    sym_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
        group,
        num_experts,
        num_max_tokens_per_rank,
        num_topk,
        hidden,
        intermediate_hidden,
    )
    y_fused = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device="cuda")

    def run_fused():
        # NOTE: 跟 SM100 test_mega_moe.py 的处理一致 —— DG_COMM_KERNEL_DEBUG=1 时
        # kernel 出口会把 sym_buffer 整块清零,所以每次都要重新拷输入
        sym_buffer.x[:num_tokens].copy_(x_fp8[0])
        sym_buffer.x_sf[:num_tokens].copy_(x_fp8[1])
        sym_buffer.topk_idx[:num_tokens].copy_(topk_idx)
        sym_buffer.topk_weights[:num_tokens].copy_(topk_weights)

        deep_gemm.fp8_mega_moe(
            y_fused,
            transformed_l1,
            transformed_l2,
            sym_buffer,
            cumulative_local_expert_recv_stats=cum_stats_fused,
            recipe=(128, 128, 128),
            activation="swiglu",
            activation_clamp=clamp_arg,
            fast_math=bool(args.fast_math),
        )
        return y_fused

    # ---- 分配 DeepEP buffer(baseline 用)----
    deep_ep = _import_deep_ep()
    ep_buffer = None
    if deep_ep is not None:
        ep_buffer = deep_ep.ElasticBuffer(
            group,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            hidden=hidden,
            num_topk=num_topk,
            use_fp8_dispatch=True,
            explicitly_destroy=True,
            allow_multiple_reduction=False,
        )

    # ----------------------------------------------------------------
    # baseline 主体:dispatch → L1 GEMM → SwiGLU+量化 → L2 GEMM → combine
    # 与 fused 用同一份 (FP8 weight, FP32 block-(128,128) SF) —— 但是 **未变换**
    # 的版本(baseline grouped GEMM 不需要 gate/up interleave)
    # ----------------------------------------------------------------
    def run_baseline():
        recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch(
            x_fp8,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            cumulative_local_expert_recv_stats=cum_stats_baseline,
            num_experts=num_experts,
            expert_alignment=alignment,
            do_cpu_sync=False,
            do_handle_copy=False,
            do_expand=True,
            use_tma_aligned_col_major_sf=False,  # SM90: row-major float SF
        )
        n = recv_x[0].size(0)

        # L1 GEMM:FP8 token @ FP8 W1 → BF16 中间激活 (gate||up 拼接)
        l1_y = torch.empty(
            (n, intermediate_hidden * 2), dtype=torch.bfloat16, device="cuda"
        )
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            recv_x,
            l1_weights,
            l1_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # Triton SwiGLU + FP8 量化(含 topk 权重乘法)
        # 注意:fused SM90 mega-MoE 的 L2 activation SFA 是 per-64-K;
        # 当前 DeepGEMM SM90 grouped GEMM 只支持 per-128-K SFA,所以性能 baseline
        # 只能用 per-128-K,但 scale 数值采用 fused 同款 UE8M0/power-of-two。
        l1_y = swiglu_apply_weight_to_fp8_triton(
            x=l1_y,
            topk_weights=recv_topk_weights,
            clamp_value=clamp_arg,
            num_per_channels=BASELINE_L2_ACT_SF_GRAN,
            use_ue8m0_scale=True,
        )

        # L2 GEMM:FP8 中间激活 @ FP8 W2 → BF16
        l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device="cuda")
        deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
            l1_y,
            l2_weights,
            l2_y,
            handle.psum_num_recv_tokens_per_expert,
            use_psum_layout=True,
            disable_ue8m0_cast=True,
        )

        # DeepEP combine:把每个 token 在 topk 个 expert 上的输出汇聚回源 rank
        return ep_buffer.combine(l2_y, handle=handle)[0]

    # ---- 打印 config ----
    dist_print("Config (H200 fused mega-MoE):", once_in_node=True)
    dist_print(f" > Tokens: {num_tokens}/{num_max_tokens_per_rank}", once_in_node=True)
    dist_print(
        f" > Hidden: {hidden}, Intermediate: {intermediate_hidden}", once_in_node=True
    )
    dist_print(
        f" > Experts: {num_topk}/{num_experts} (per-rank: {num_experts_per_rank})",
        once_in_node=True,
    )
    dist_print(
        f" > Activation SF: fused L2 per-{FUSED_L2_ACT_SF_GRAN} UE8M0, "
        f"baseline L2 per-{BASELINE_L2_ACT_SF_GRAN} UE8M0 "
        f"(SM90 grouped GEMM constraint)",
        once_in_node=True,
    )
    dist_print(
        f" > Buffer: {sym_buffer.buffer.nbytes / 2**30:.3f} GiB", once_in_node=True
    )
    dist_print(once_in_node=True)

    # ---- 跑一次确保不报错(fused + 可选 baseline)----
    y = run_fused()
    assert y.shape == (num_tokens, hidden) and y.dtype == torch.bfloat16, (
        f"fused 输出 shape/dtype 异常: shape={y.shape}, dtype={y.dtype}"
    )
    if ep_buffer is not None:
        out_b = run_baseline()
        assert out_b.shape == (num_tokens, hidden) and out_b.dtype == torch.bfloat16, (
            f"baseline 输出 shape/dtype 异常: shape={out_b.shape}, dtype={out_b.dtype}"
        )
        if args.check_output_diff:
            diff = (y.float() - out_b.float()).abs()
            denom = out_b.float().abs().mean().clamp_min(1e-12)
            dist_print(
                "Output diff (fused vs legacy-per128 baseline):", once_in_node=True
            )
            dist_print(
                f" > max_abs={diff.max().item():.6e}, "
                f"mean_abs={diff.mean().item():.6e}, "
                f"mean_abs/mean_ref={diff.mean().div(denom).item():.6e}",
                once_in_node=True,
            )
            dist_print(once_in_node=True)

    # ---- 统计本 rank 实际接收的 token 数与触达的 expert 数 ----
    # 把所有 rank 的 topk_idx 收齐,再把不落在本 rank 持有 expert 范围内的条目
    # 标成 -1;剩下的非 -1 条目数即"被路由进本 rank 的 (token, slot) 总数"。
    gathered_topk_idx = uneven_all_gather(topk_idx, group=group)
    gathered_topk_idx[
        (gathered_topk_idx < rank_idx * num_experts_per_rank)
        | (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)
    ] = -1
    num_recv_tokens = int((gathered_topk_idx != -1).sum().item())
    num_touched_experts = max(torch.unique(gathered_topk_idx.flatten()).numel() - 1, 0)

    # ---- benchmark ----
    # fused:bench_kineto 抓 sm90_fp8_mega_moe_impl 的 GPU 段(不含 host overhead)
    t_fused = bench_kineto(
        run_fused,
        SM90_KERNEL_NAME,
        num_tests=args.num_bench_tests,
        barrier=lambda: ep_buffer.barrier(use_comm_stream=False)
        if ep_buffer is not None
        else dist.barrier(),
        trace_path=(
            f"{args.dump_profile_traces}/mega_moe_hopper_rank{rank_idx}.json"
            if args.dump_profile_traces
            else None
        ),
    )
    # baseline:cuda events 中位数(tilelang.do_bench 在 H200 不一定有,统一用 events)
    t_baseline = (
        _bench_cuda_events(
            run_baseline,
            num_warmup=args.num_warmup,
            num_repeat=args.num_repeat,
            l2_flush_gb=args.l2_flush_gb,
        )
        if ep_buffer is not None
        else 0.0
    )

    def safe_div(a, b):
        return float("nan") if b == 0 else a / b

    # 端到端 TFLOPS:3 个 matmul(L1 gate、L1 up、L2),每个 2*M*N*K,M=num_recv_tokens
    tflops = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused
    )

    # HBM 字节估算(SM90: weight 是 FP8 = 1B/elem,与 SM100 FP4=0.5B 不同)
    l1_weight_bytes = num_touched_experts * intermediate_hidden * 2 * hidden
    l2_weight_bytes = num_touched_experts * hidden * intermediate_hidden
    l1_weight_sf_bytes = (
        num_touched_experts
        * (intermediate_hidden * 2 // WEIGHT_SF_GRAN_MN)
        * (hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l2_weight_sf_bytes = (
        num_touched_experts
        * (hidden // WEIGHT_SF_GRAN_MN)
        * (intermediate_hidden // WEIGHT_SF_GRAN_K)
        * 4
    )
    l1_input_sf_bytes = num_recv_tokens * (hidden // L1_ACT_SF_GRAN) * 4
    l2_act_sf_bytes = (
        num_recv_tokens * (intermediate_hidden // FUSED_L2_ACT_SF_GRAN) * 4
    )
    num_hbm_bytes = (
        l1_weight_bytes
        + l2_weight_bytes  # weights (FP8)
        + l1_weight_sf_bytes
        + l2_weight_sf_bytes  # weight SF (FP32)
        + num_recv_tokens * hidden
        + l1_input_sf_bytes  # L1 输入读 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L1 输出写 (FP8 + SF)
        + num_recv_tokens * intermediate_hidden
        + l2_act_sf_bytes  # L2 输入读 (FP8 + SF)
        + num_recv_tokens * hidden * 2  # L2 输出写 (BF16)
    )
    hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused)

    # NVLink 字节:dispatch 拉 token + input SF + topk weight,combine 写回 BF16
    num_nvlink_bytes = num_recv_tokens * (hidden + hidden // 32 + 4 + hidden * 2)
    nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused)

    # combine reduction 串行下界(解析估计;6.5e12 = HBM 串行 reduction 经验吞吐 B/s)
    t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12

    # overlap 校正:扣掉 fused 中无法重叠的串行 reduction 段后估计稳态吞吐
    approx_factor = t_fused / max(t_fused - t_reduction, 1e-12)

    # baseline 用同一份 FLOPs / HBM 字节,时间换成 t_baseline
    tflops_baseline = safe_div(
        2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_baseline
    )
    hbm_gbs_baseline = safe_div(num_hbm_bytes / 1e9, t_baseline)
    nvlink_gbs_baseline = safe_div(num_nvlink_bytes / 1e9, t_baseline)

    dist_print("Performance:", once_in_node=True)
    dist_print(
        f" > [fused]    EP {rank_idx:2}/{num_ranks} | "
        f"{tflops:4.0f} TFLOPS | "
        f"overlap: {tflops * approx_factor:4.0f} TFLOPS, "
        f"HBM {hbm_gbs * approx_factor:4.0f} GB/s, "
        f"NVL {nvlink_gbs * approx_factor:3.0f} GB/s | "
        f"{t_fused * 1e6:6.0f} us, "
        f"reduction: {t_reduction * 1e6:5.1f} us"
    )
    if ep_buffer is not None:
        speedup = safe_div(t_baseline, t_fused)
        dist_print(
            f" > [baseline] EP {rank_idx:2}/{num_ranks} | "
            f"{tflops_baseline:4.0f} TFLOPS | "
            f"               HBM {hbm_gbs_baseline:4.0f} GB/s, "
            f"NVL {nvlink_gbs_baseline:3.0f} GB/s | "
            f"{t_baseline * 1e6:6.0f} us | "
            f"t_baseline/t_fused = {speedup:.2f}x "
            f"({'fused 更快' if speedup > 1 else 'baseline 更快'})"
        )
    else:
        dist_print(" > [baseline] (no baseline: deep_ep unavailable)", once_in_node=True)

    # ---- 清理 ----
    dist.barrier()
    sym_buffer.destroy()
    if ep_buffer is not None:
        ep_buffer.destroy()
    dist.destroy_process_group()


# ============================================================================
# 模块 6:argparse + spawn
# ============================================================================

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="H200 mega-MoE: fused (deep_gemm.fp8_mega_moe) vs DeepEP+grouped-FP8 baseline"
    )

    # 资源
    parser.add_argument(
        "--num-processes", type=int, default=8, help="spawn 出来的进程数(一卡一进程)"
    )

    # 模型形状
    # 注:SM90 fused kernel 要求 intermediate_hidden ≤ 4096
    parser.add_argument("--num-max-tokens-per-rank", type=int, default=8192)
    parser.add_argument(
        "--num-tokens",
        type=int,
        default=0,
        help="per-rank 实际 token 数;0 表示用 num-max-tokens-per-rank",
    )
    parser.add_argument("--hidden", type=int, default=7168)
    parser.add_argument(
        "--intermediate-hidden",
        type=int,
        default=3072,
        help="中间层维度(≤ 4096,受 SM90 l2_arrival_mask 约束)",
    )
    parser.add_argument(
        "--activation-clamp",
        type=float,
        default=10.0,
        help="SwiGLU 前对 gate/up 的 clamp 阈值;传 inf 表示关闭",
    )
    parser.add_argument("--num-experts", type=int, default=384)
    parser.add_argument("--num-topk", type=int, default=6)
    parser.add_argument(
        "--fast-math",
        type=int,
        default=1,
        help="fused 内 SwiGLU 是否启用 fast-math(0/1)",
    )

    # 测时
    parser.add_argument(
        "--num-bench-tests",
        type=int,
        default=30,
        help="bench_kineto 抓 fused 时的迭代数",
    )
    parser.add_argument(
        "--num-warmup", type=int, default=5, help="baseline cuda events warmup"
    )
    parser.add_argument(
        "--num-repeat", type=int, default=20, help="baseline cuda events 测时迭代"
    )
    parser.add_argument(
        "--l2-flush-gb",
        type=float,
        default=8.0,
        help="baseline event 测时前用于 flush L2 的临时写入大小;0 表示关闭",
    )
    parser.add_argument(
        "--check-output-diff",
        type=int,
        default=0,
        help="非 0 时打印 fused 与 legacy-per128 baseline 的输出差异(预期非 bitwise)",
    )
    parser.add_argument(
        "--dump-profile-traces",
        type=str,
        default="",
        help="非空时把 fused 的 Chrome trace 写到该目录(每 rank 一份)",
    )

    args = parser.parse_args()

    if args.dump_profile_traces:
        os.makedirs(args.dump_profile_traces, exist_ok=True)

    # 多进程启动:每个进程对应一个 GPU;test() 内部用 init_dist 建 NCCL group
    torch.multiprocessing.spawn(
        test, args=(args.num_processes, args), nprocs=args.num_processes
    )

deepseek-ai/DeepEP#629 没有RDMA 的8卡需要依赖这个PR 才能跑ElasticBuffer接口,另外EP4 是不是太小了?整体应该还是bound 在HBM 读取上面了,看不到megamoe 的收益。

@qiushixiaoyu
Copy link
Copy Markdown
Author

qiushixiaoyu commented May 18, 2026

export PYTHONPATH=/workspace/DeepGEMM:/workspace/DeepEP:${PYTHONPATH:-}
export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/tvm_ffi/lib:${LD_LIBRARY_PATH:-}
python3 tests/test_mega_moe_hopper.py
--num-processes 8
--num-max-tokens-per-rank
--num-tokens
--hidden 4096
--intermediate-hidden 2048
--num-experts 256
--num-topk 6
--num-bench-tests 5
--num-warmup 2
--num-repeat 5
--l2-flush-gb 0
--run-baseline

Batch Fused avg us Baseline avg us Baseline / Fused Fused TFLOPS Baseline TFLOPS Fused HBM GB/s Baseline HBM GB/s Status
1 183.4 327.6 1.787 1.6 1.0 755.1 422.8 ok
2 263.0 380.4 1.446 2.1 1.5 1005.5 695.6 ok
4 406.1 497.4 1.225 3.0 2.4 1070.5 873.6 ok
8 497.1 546.1 1.099 4.8 4.5 1293.1 1177.2 ok
16 566.0 641.2 1.133 8.4 7.4 1376.8 1214.6 ok
32 576.0 651.0 1.130 16.8 14.8 1404.6 1242.4 ok
64 592.5 653.2 1.103 32.8 29.6 1371.9 1242.5 ok
128 597.9 680.1 1.138 64.9 56.9 1370.9 1202.9 ok
512 1144.0 1220.9 1.067 135.9 126.6 752.1 702.0 ok
1024 1989.5 2189.1 1.100 156.0 141.1 458.8 415.0 ok
4096 6949.8 6913.9 0.995 179.0 179.0 176.0 176.0 ok
8192 13514.9 13343.6 0.987 184.2 185.4 121.2 122.2 ok

@Stone749990226
Copy link
Copy Markdown

你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336

@Stone749990226
Copy link
Copy Markdown

你跑过B300的ncu报告吗?我跑出来的B300的报告SM和Memory利用率非常低,不知道是不是跑错了,感觉有点奇怪呢。#336

就是官方的Mega MoE的kernel的B300的报告,我在8卡上跑的

@qiushixiaoyu
Copy link
Copy Markdown
Author

@Stone749990226 我只有H20环境

@qinqinwo
Copy link
Copy Markdown

export PYTHONPATH=/workspace/DeepGEMM:/workspace/DeepEP:${PYTHONPATH:-} export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/tvm_ffi/lib:${LD_LIBRARY_PATH:-} python3 tests/test_mega_moe_hopper.py --num-processes 8 --num-max-tokens-per-rank --num-tokens --hidden 4096 --intermediate-hidden 2048 --num-experts 256 --num-topk 6 --num-bench-tests 5 --num-warmup 2 --num-repeat 5 --l2-flush-gb 0 --run-baseline

Batch Fused avg us Baseline avg us Baseline / Fused Fused TFLOPS Baseline TFLOPS Fused HBM GB/s Baseline HBM GB/s Status
1 183.4 327.6 1.787 1.6 1.0 755.1 422.8 ok
2 263.0 380.4 1.446 2.1 1.5 1005.5 695.6 ok
4 406.1 497.4 1.225 3.0 2.4 1070.5 873.6 ok
8 497.1 546.1 1.099 4.8 4.5 1293.1 1177.2 ok
16 566.0 641.2 1.133 8.4 7.4 1376.8 1214.6 ok
32 576.0 651.0 1.130 16.8 14.8 1404.6 1242.4 ok
64 592.5 653.2 1.103 32.8 29.6 1371.9 1242.5 ok
128 597.9 680.1 1.138 64.9 56.9 1370.9 1202.9 ok
512 1144.0 1220.9 1.067 135.9 126.6 752.1 702.0 ok
1024 1989.5 2189.1 1.100 156.0 141.1 458.8 415.0 ok
4096 6949.8 6913.9 0.995 179.0 179.0 176.0 176.0 ok
8192 13514.9 13343.6 0.987 184.2 185.4 121.2 122.2 ok

请问,这是用最新代码跑的吗?baseline是和sm100上一样,deepep v2+deepgemm完全没任何overlap的吗?我之前跑出来的结果如下:

Model Tokens/rank Backend Time us TFLOPS HBM GB/s NVL GB/s Speedup PR316 time us Time ratio Status
flash 1 pr323-sm90-fp8 360 1 490 0   56.5 6.372 ok
flash 512 pr323-sm90-fp8 2027 76 423 19   146.5 13.836 ok
flash 8192 pr323-sm90-fp8 22957 107 70 26   1283.1 17.892 ok
flash 32768 pr323-sm90-fp8 88802 111 46 27   4855.5 18.289 ok
pro 1 pr323-sm90-fp8 602 1 659 0   108.1 5.569 ok
pro 512 pr323-sm90-fp8 4203 97 776 16   369.6 11.372 ok
pro 8192 pr323-sm90-fp8 58415 111 78 18   2818.5 20.726 ok
pro 32768 pr323-sm90-fp8 220394 118 39 19   10655.2 20.684 ok

@usernamehaha2022
Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap:
image
会比baseline差。
我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果:
--num-processes 4 --num-experts 128 --hidden 3072 on 4*H800
image
看上去性能合理

@usernamehaha2022
Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

test_mega_moe_sm90.py
对应的测试在这个文件。如果有人对我们的优化感兴趣可以一起讨论🤔

@leiyin22
Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

test_mega_moe_sm90.py 对应的测试在这个文件。如果有人对我们的优化感兴趣可以一起讨论🤔

源码有吗?

@foobar2023xx
Copy link
Copy Markdown

在我们的测试里,nvcc编译出来这个版本的寄存器spill很多,对比deepepv2+deepgemm non-overlap: image 会比baseline差。 我们尝试手动解决了一下,并对计算warpgroup做了修改,拿4卡H800测试的效果: --num-processes 4 --num-experts 128 --hidden 3072 on 4*H800 image 看上去性能合理

我在H800上测试的时候,没有观察到明显的寄存器spill,请问你们是基于当前最新版本测试的吗?

Running NVCC command: cd /tmp/dg_sm90_spill_v1/tmp && /usr/local/cuda/bin/nvcc /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cu -cubin -o /tmp/dg_sm90_spill_v1/tmp/13989-a1ffc508-d2e5b710-f4640373/kernel.cubin -std=c++20 --diag-suppress=39,161,174,177,186,940 --ptxas-options=--register-usage-level=10 --ptxas-options=--verbose,--warn-on-local-memory-usage -I/workspace/DeepGEMM/deep_gemm/include --gpu-architecture=sm_90a --compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi -O3 --expt-relaxed-constexpr --expt-extended-lambda
ptxas info    : (C7510) Potential Performance Loss: wgmma.mma_async instructions are serialized due to wgmma pipeline crossing function boundary at a function call in the function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_'
ptxas warning : Local memory used for function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_', size of stack frame: 56 bytes
ptxas info    : 474 bytes gmem
ptxas info    : Compiling entry function '_ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_' for 'sm_90a'
ptxas info    : Function properties for _ZN9deep_gemm22sm90_fp8_mega_moe_implILj8448ELj4096ELj2048ELj256ELj6ELj64ELj64ELj128ELj128ELj215040ELj3440640ELj7ELj0ELj128ELj128ELj128ELj132ELj4ELf41200000ELb1ELj1ELj48ELj40ELj208ELj0ELj4096ELj4096ELj4096ELj2048ELj4ELj4ELj4ELj1ELj384ELj5ELj64EEEvPvPijNS_6layout9SymBufferIXT16_EEE14CUtensorMap_stS6_S6_PKfS6_S6_S6_S6_S8_
    56 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 168 registers, used 16 barriers, 56 bytes cumulative stack size
ptxas info    : Compile time = 582.487 ms

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants