@@ -159,22 +159,23 @@ CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.
159159
160160### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm
161161
162- MXFP8:
163-
164- | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
165- | ------------------------| -----------------| -------------------| ------------------------|
166- | (128000, 8192, 5120, 1) | 40463 | 24406 | 1.658x |
167- | (128000, 8192, 5120, 2) | 35494.5 | 24705.1 | 1.437x |
168- | (128000, 8192, 5120, 4) | 38879.3 | 24508.5 | 1.586x |
169- | (128000, 8192, 5120, 8) | 35714.6 | 25937.6 | 1.377x |
170- | (128000, 1536, 5120, 1) | 6353.06 | 7401.54 | 0.858x |
171- | (128000, 1536, 5120, 2) | 6511.65 | 6729.33 | 0.968x |
172- | (128000, 1536, 5120, 4) | 6455.2 | 6626.5 | 0.974x |
173- | (128000, 1536, 5120, 8) | 7716.13 | 6516.74 | 1.184x |
174- | (128000, 2048, 7168, 1) | 11758 | 11255.7 | 1.045x |
175- | (128000, 2048, 7168, 2) | 15012.9 | 9917.9 | 1.514x |
176- | (128000, 2048, 7168, 4) | 14904.2 | 10493.8 | 1.42x |
177- | (128000, 2048, 7168, 8) | 13178 | 9638.38 | 1.367x |
162+ **MXFP8 with Llama4 17b 16e shapes** (with G=1-8 to simulate different degrees of expert parallelism)
163+
164+ | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
165+ | ----------------------- | --------------: | ----------------: | ---------------------: |
166+ | (128000, 8192, 5120, 1) | 43140.20 | 23867.00 | 1.808x |
167+ | (128000, 8192, 5120, 2) | 39487.60 | 23359.00 | 1.690x |
168+ | (128000, 8192, 5120, 4) | 39189.20 | 23945.50 | 1.637x |
169+ | (128000, 8192, 5120, 8) | 37700.70 | 22170.60 | 1.700x |
170+
171+ **MXFP8 with DeepSeekV3 shapes** (with G=1-8 to simulate different degrees of expert parallelism)
172+
173+ | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
174+ | ----------------------- | --------------: | ----------------: | ---------------------: |
175+ | (128000, 2048, 7168, 1) | 13064.80 | 10996.00 | 1.188x |
176+ | (128000, 2048, 7168, 2) | 14900.20 | 11283.40 | 1.321x |
177+ | (128000, 2048, 7168, 4) | 15823.60 | 9919.36 | 1.595x |
178+ | (128000, 2048, 7168, 8) | 14966.80 | 10397.20 | 1.440x |
178179
179180
180181To reproduce this benchmark, on a B200 GPU machine, run the following command: