Skip to content

[Perf] SM100 FP8 GEMM Optimizations after cutlass_profiler #20071

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ilmarkov
Copy link
Contributor

@ilmarkov ilmarkov commented Jun 25, 2025

Additional performance optimizations after #19566

Tune CUTLASS configs for M <= 256 using cutlass_profiler insights.
The PR allows to get from 1.12x to 1.5x speedup for certain NxK pairs without affecting performance for the other NxK.

Cutlass profiler results.

We pick the best configuration that has best performance across different NxK pairs for static M.
cta_m, cta_n, cta_k, cluster_m, cluster_n, cluster_k are the parameters. Rank represents the rank of the performance parameters set for a single NxK in cutlass_profiler comparison, GFPLOPs - performance results in the corresponding benchmark for all NxK pairs.
M=16

cta_m cta_n cta_k cluster_m cluster_n cluster_k Rank GFLOPs
64 64 128 1 4 1 ['16x2560x8192:3', '16x28672x4096:1', '16x4096x14336:5', '16x14336x8192:8', '16x4096x4096:8', '16x8192x7168:2', '16x6144x4096:2', '16x8192x2048:1'] [54489.5, 146659.0, 101793.0, 139143.0, 52288.0, 114496.0, 77335.6, 54581.4]
128 128 128 2 1 1 ['16x2560x8192:6', '16x28672x4096:11', '16x4096x14336:4', '16x14336x8192:6', '16x4096x4096:7', '16x8192x7168:5', '16x6144x4096:8', '16x8192x2048:4'] [54301.8, 133538.0, 101793.0, 140305.0, 52292.9, 111033.0, 66998.4, 52344.8]
64 64 128 1 2 1 ['16x2560x8192:2', '16x28672x4096:9', '16x4096x14336:3', '16x14336x8192:15', '16x4096x4096:1', '16x8192x7168:3', '16x6144x4096:5', '16x8192x2048:8'] [54492.3, 135616.0, 101841.0, 136116.0, 52382.7, 113502.0, 74268.9, 52338.3]
64 64 256 1 4 1 ['16x2560x8192:14', '16x28672x4096:4', '16x4096x14336:13', '16x14336x8192:10', '16x4096x4096:5', '16x8192x7168:1', '16x6144x4096:3', '16x8192x2048:2'] [50150.8, 142839.0, 99246.8, 137627.0, 52317.3, 114500.0, 76734.4, 54349.9]
64 64 128 1 8 1 ['16x2560x8192:5', '16x28672x4096:3', '16x4096x14336:6', '16x14336x8192:9', '16x4096x4096:2', '16x8192x7168:13', '16x6144x4096:1', '16x8192x2048:14'] [54485.2, 145789.0, 101781.0, 137939.0, 52336.9, 96850.9, 77881.2, 52295.9]
64 64 128 1 1 1 ['16x2560x8192:1', '16x28672x4096:15', '16x4096x14336:1', '16x14336x8192:4', '16x4096x4096:10', '16x8192x7168:11', '16x6144x4096:10', '16x8192x2048:6'] [54492.3, 125416.0, 104745.0, 141677.0, 52281.5, 101680.0, 65444.6, 52341.5]
128 128 128 2 2 1 ['16x2560x8192:9', '16x28672x4096:13', '16x4096x14336:9', '16x14336x8192:11', '16x4096x4096:4', '16x8192x7168:10', '16x6144x4096:9', '16x8192x2048:10'] [52898.1, 130875.0, 101497.0, 137532.0, 52330.4, 101851.0, 65446.3, 52310.5]

M=32

cta_m cta_n cta_k cluster_m cluster_n cluster_k Rank GFLOPs
64 64 128 1 1 1 ['32x2560x8192:1', '32x4096x4096:1', '32x14336x8192:4', '32x8192x7168:1', '32x4096x14336:1', '32x8192x2048:2', '32x28672x4096:2', '32x6144x4096:1'] [109044.0, 104880.0, 281366.0, 255430.0, 205877.0, 106808.0, 297854.0, 157006.0]
64 64 128 1 2 1 ['32x2560x8192:3', '32x4096x4096:2', '32x14336x8192:3', '32x8192x7168:4', '32x4096x14336:5', '32x8192x2048:8', '32x28672x4096:6', '32x6144x4096:3'] [108976.0, 104680.0, 281896.0, 232749.0, 203572.0, 104680.0, 287400.0, 156679.0]
64 64 128 1 4 1 ['32x2560x8192:4', '32x4096x4096:9', '32x14336x8192:9', '32x8192x7168:2', '32x4096x14336:3', '32x8192x2048:1', '32x28672x4096:3', '32x6144x4096:2'] [108973.0, 104583.0, 277114.0, 235758.0, 203604.0, 109365.0, 297120.0, 156849.0]
64 64 256 1 1 1 ['32x2560x8192:7', '32x4096x4096:8', '32x14336x8192:8', '32x8192x7168:3', '32x4096x14336:7', '32x8192x2048:4', '32x28672x4096:13', '32x6144x4096:5'] [107446.0, 104651.0, 277121.0, 234253.0, 202564.0, 104794.0, 282269.0, 155603.0]
128 128 128 2 1 1 ['32x2560x8192:5', '32x4096x4096:6', '32x14336x8192:7', '32x8192x7168:7', '32x4096x14336:6', '32x8192x2048:5', '32x28672x4096:12', '32x6144x4096:11'] [108418.0, 104661.0, 279744.0, 229269.0, 203565.0, 104784.0, 282374.0, 151110.0]
64 64 256 1 4 1 ['32x2560x8192:14', '32x4096x4096:10', '32x14336x8192:11', '32x8192x7168:5', '32x4096x14336:13', '32x8192x2048:3', '32x28672x4096:9', '32x6144x4096:6'] [98946.0, 104579.0, 274213.0, 229839.0, 194072.0, 105838.0, 283253.0, 154406.0]
128 128 128 2 2 1 ['32x2560x8192:9', '32x4096x4096:13', '32x14336x8192:14', '32x8192x7168:8', '32x4096x14336:8', '32x8192x2048:11', '32x28672x4096:8', '32x6144x4096:10'] [104131.0, 104573.0, 272192.0, 228992.0, 201892.0, 104589.0, 284123.0, 151570.0]

M=64

cta_m cta_n cta_k cluster_m cluster_n cluster_k Rank GFLOPs
64 64 128 1 1 1 ['64x28672x4096:5', '64x14336x8192:2', '64x2560x8192:1', '64x4096x4096:1', '64x6144x4096:1', '64x8192x7168:1', '64x4096x14336:1', '64x8192x2048:1'] [570911.0, 563819.0, 218094.0, 209544.0, 313699.0, 495003.0, 407603.0, 260269.0]
64 64 128 1 4 1 ['64x28672x4096:2', '64x14336x8192:3', '64x2560x8192:4', '64x4096x4096:7', '64x6144x4096:2', '64x8192x7168:2', '64x4096x14336:4', '64x8192x2048:2'] [581315.0, 561891.0, 217901.0, 209328.0, 313679.0, 468488.0, 407144.0, 241908.0]
64 64 128 1 2 1 ['64x28672x4096:8', '64x14336x8192:7', '64x2560x8192:2', '64x4096x4096:2', '64x6144x4096:4', '64x8192x7168:3', '64x4096x14336:3', '64x8192x2048:4'] [564532.0, 557186.0, 217975.0, 209518.0, 313084.0, 466719.0, 407186.0, 228778.0]
64 64 256 1 1 1 ['64x28672x4096:10', '64x14336x8192:8', '64x2560x8192:6', '64x4096x4096:4', '64x6144x4096:5', '64x8192x7168:4', '64x4096x14336:7', '64x8192x2048:5'] [563699.0, 548291.0, 213515.0, 209335.0, 311197.0, 465213.0, 403999.0, 227815.0]
64 64 256 1 2 1 ['64x28672x4096:11', '64x14336x8192:13', '64x2560x8192:8', '64x4096x4096:5', '64x6144x4096:6', '64x8192x7168:5', '64x4096x14336:9', '64x8192x2048:3'] [563246.0, 543370.0, 205306.0, 209328.0, 308140.0, 459435.0, 394017.0, 229262.0]
128 128 128 2 1 1 ['64x28672x4096:14', '64x14336x8192:9', '64x2560x8192:7', '64x4096x4096:6', '64x6144x4096:9', '64x8192x7168:7', '64x4096x14336:5', '64x8192x2048:7'] [562585.0, 547882.0, 209893.0, 209328.0, 301623.0, 458252.0, 405590.0, 217406.0]
128 128 128 2 2 1 ['64x28672x4096:7', '64x14336x8192:10', '64x2560x8192:10', '64x4096x4096:9', '64x6144x4096:10', '64x8192x7168:8', '64x4096x14336:8', '64x8192x2048:8'] [564539.0, 547824.0, 201562.0, 209322.0, 297717.0, 457984.0, 398070.0, 212527.0]

M=128

cta_m cta_n cta_k cluster_m cluster_n cluster_k Rank GFLOPs
128 128 128 2 1 1 ['128x2560x8192:8', '128x4096x4096:9', '128x14336x8192:4', '128x8192x7168:1', '128x8192x2048:1', '128x28672x4096:3', '128x6144x4096:2', '128x4096x14336:5'] [415092.0, 418265.0, 1065420.0, 916540.0, 426784.0, 1097860.0, 577819.0, 801383.0]
128 128 128 2 2 1 ['128x2560x8192:10', '128x4096x4096:8', '128x14336x8192:2', '128x8192x7168:2', '128x8192x2048:2', '128x28672x4096:2', '128x6144x4096:1', '128x4096x14336:6'] [390743.0, 418278.0, 1081600.0, 915932.0, 420450.0, 1126510.0, 582752.0, 770648.0]
64 64 128 1 1 1 ['128x2560x8192:1', '128x4096x4096:1', '128x14336x8192:13', '128x8192x7168:5', '128x8192x2048:5', '128x28672x4096:8', '128x6144x4096:5', '128x4096x14336:1'] [435950.0, 418656.0, 907425.0, 814372.0, 418785.0, 1001930.0, 523095.0, 814753.0]
64 64 128 2 1 1 ['128x2560x8192:3', '128x4096x4096:7', '128x14336x8192:15', '128x8192x7168:9', '128x8192x2048:3', '128x28672x4096:9', '128x6144x4096:4', '128x4096x14336:2'] [433685.0, 418291.0, 891168.0, 814302.0, 418837.0, 978594.0, 523163.0, 814400.0]
64 64 128 1 2 1 ['128x2560x8192:2', '128x4096x4096:10', '128x14336x8192:12', '128x8192x7168:8', '128x8192x2048:4', '128x28672x4096:10', '128x6144x4096:8', '128x4096x14336:3'] [435825.0, 418265.0, 909419.0, 814330.0, 418811.0, 977952.0, 520539.0, 814287.0]

Kernel benchmarks using #17126 on B200.
python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
and python benchmarks/kernels/bench_fp8_gemm.py --model meta-llama/Llama-3.3-70B-Instruct --tp-sizes 4

meta-llama/Llama-3.1-8B-Instruct

N=4096 K=14336

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 10.22 10.08 0.99
16 162.61 161.47 0.99
32 323.89 321.19 0.99
64 642.23 637.38 0.99
128 1225.21 1234.45 1.01
192 1597.06 1649.77 1.03
256 2059.40 2242.74 1.09

meta-llama/Llama-3.1-8B-Instruct

N=4096 K=4096

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 5.86 5.69 0.97
16 93.22 91.95 0.99
32 185.81 188.96 1.02
64 367.71 373.51 1.02
128 703.56 708.25 1.01
192 925.66 989.09 1.07
256 1246.88 1335.40 1.07

meta-llama/Llama-3.1-8B-Instruct

N=6144 K=4096

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 8.41 7.86 0.93
16 137.74 128.51 0.93
32 271.65 278.66 1.03
64 533.76 553.45 1.04
128 1019.67 1041.53 1.02
192 1019.66 1024.30 1.00
256 1373.78 1354.85 0.99

meta-llama/Llama-3.1-8B-Instruct

N=28672 K=4096

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 10.10 10.62 1.05
16 159.61 168.90 1.06
32 314.99 332.05 1.05
64 619.51 648.96 1.05
128 1182.28 1192.15 1.01
192 1332.86 1320.79 0.99
256 1663.22 1630.00 0.98

meta-llama/Llama-3.3-70B-Instruct

N=14336 K=8192

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 9.91 9.41 0.95
16 157.88 166.88 1.06
32 313.35 333.66 1.06
64 618.64 622.06 1.01
128 1170.04 1195.33 1.02
192 1280.06 1238.12 0.97
256 1610.38 1620.58 1.01

meta-llama/Llama-3.3-70B-Instruct

N=2560 K=8192

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 5.27 5.21 0.99
16 83.95 83.13 0.99
32 167.41 165.62 0.99
64 332.07 328.48 0.99
128 635.35 634.71 1.00
192 938.49 935.84 1.00
256 1247.44 1245.94 1.00

meta-llama/Llama-3.3-70B-Instruct

N=8192 K=2048

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 5.65 6.33 1.12
16 90.38 106.86 1.18
32 178.05 195.24 1.10
64 345.92 467.90 1.35
128 654.82 861.29 1.31
192 943.38 935.16 0.99
256 1238.14 1250.33 1.01

meta-llama/Llama-3.3-70B-Instruct

N=8192 K=7168

batch_size fp8-tensor-w-tensor-a-noquant (Before) fp8-tensor-w-tensor-a-noquant (After) speedup
1 9.84 12.38 1.26
16 156.65 199.27 1.27
32 309.39 451.10 1.46
64 608.19 938.37 1.54
128 1158.40 1777.43 1.53
192 1502.84 1582.51 1.05
256 2021.96 2042.31 1.01

Raw results:

Before:
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                           
BF16 vs FP8 GEMMs:                                                                                                                         
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant     
0         1.0     5.741441               5.889322               5.363307                       8.414576                       8.380952
1        16.0    94.679950              93.698559              85.468221                     137.739731                     137.578115
2        32.0   183.707332             189.485936             172.590395                     271.646393                     271.306649
3        64.0   403.369655             373.401188             340.459855                     533.758398                     535.405370
4       128.0   729.183154             713.216098             638.050011                    1019.671603                    1020.946228
5       192.0   767.153850             762.612420             711.318064                    1019.661534                    1029.178464
6       256.0  1095.091461            1009.552856             942.036030                    1373.779732                    1367.500228
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                                
BF16 vs FP8 GEMMs:                                                                                                                         
   batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant 
0         1.0    2.523119               3.974651               3.627800                       5.859361                       5.854528 
1        16.0   72.806091              63.548352              58.053531                      93.220465                      93.379579 
2        32.0  144.698420             128.036377             116.704942                     185.809005                     185.671120 
3        64.0  286.282226             251.610524             229.856624                     367.708809                     367.587295 
4       128.0  563.190528             489.574274             442.762365                     703.557565                     700.682444
5       192.0  791.232107             642.376266             578.835411                     925.657288                     954.357496      
6       256.0  888.059340             844.223230             771.735477                    1246.879553                    1249.429003
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                         
BF16 vs FP8 GEMMs:                                                                                                                         
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.894464               8.858697               8.630923                      10.100937                      10.063045
1        16.0    81.532051             139.532974             136.451239                     159.605105                     159.714319
2        32.0   160.943398             275.949885             269.435411                     314.992892                     314.751329     
3        64.0   319.878854             543.014128             529.876539                     619.514309                     618.966013 
4       128.0   613.371075            1030.604967            1012.227544                    1182.276151                    1177.310747
5       192.0   837.895127            1198.302184            1236.082999                    1332.857494                    1362.815805
6       256.0  1037.993814            1402.775098            1485.557965                    1663.216280                    1495.555091     
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:                                                         
BF16 vs FP8 GEMMs:                                                                                                                         
   batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0    4.905350               6.949511               6.300312                      10.215670                      10.216556
1        16.0   80.882981             111.501007             101.083220                     162.606038                     162.555430
2        32.0  141.641973             221.826193             201.151377                     323.891604                     323.966491
3        64.0  305.293094             438.228166             397.044911                     642.230583                     642.409625
4       128.0  594.671522             782.029357             763.854150                    1225.205369                    1228.060189      
5       192.0  641.518559            1094.145217             948.899860                    1597.055535                    1499.209977 
6       256.0  921.058567             956.245340            1316.473683                    2059.402628                    2055.774977 

meta-llama/Llama-3.3-70B-Instruct, N=2560 K=8192, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0    2.247150               3.688224               3.340765                       5.269016                       5.263209
1        16.0   56.084784              58.365381              52.911891                      83.947982                      83.922893
2        32.0  111.571267             116.086121             105.237080                     167.412861                     167.339616
3        64.0  218.608519             229.665938             207.591697                     332.072649                     331.695491
4       128.0  426.064208             442.797352             399.597632                     635.349616                     633.780546
5       192.0  648.557400             623.430825             559.161170                     938.490452                     935.055219
6       256.0  809.177512             826.646153             744.787525                    1247.438434                    1243.785137
meta-llama/Llama-3.3-70B-Instruct, N=8192 K=2048, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     4.886522               4.009507               3.664337                       5.650769                       5.620400
1        16.0    92.691298              63.509525              58.944134                      90.376268                      90.132737
2        32.0   183.956786             127.050793             117.698866                     178.046352                     178.503483
3        64.0   357.481226             249.046628             231.077185                     345.924649                     347.610884
4       128.0   654.879700             466.611188             424.118639                     654.817899                     647.658903
5       192.0   896.451572             666.452690             600.419689                     943.381190                     915.537523
6       256.0  1058.266034             872.727923             794.987820                    1238.137168                    1222.881190
meta-llama/Llama-3.3-70B-Instruct, N=14336 K=8192, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.767851               8.269581               8.117290                       9.909120                       9.910631
1        16.0    95.268574             133.095823             129.683937                     157.880487                     157.992440
2        32.0   193.802700             262.948193             256.521498                     313.348997                     313.429716
3        64.0   365.626929             515.916613             505.658013                     618.635402                     619.005317
4       128.0   715.076140             959.378082             975.503983                    1170.040199                    1190.521965
5       192.0   902.653741            1091.773690            1110.502427                    1280.060404                    1311.778132
6       256.0  1108.966212            1529.349831            1464.603339                    1610.377929                    1512.480083
meta-llama/Llama-3.3-70B-Instruct, N=8192 K=7168, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
   batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0         1.0     5.413602               7.548899               7.047636                       9.838443                       9.825678
1        16.0    86.738574             119.896628             111.867585                     156.650353                     156.586286
2        32.0   170.657537             241.528158             224.942129                     309.385464                     309.191719
3        64.0   329.848545             475.250413             442.987203                     608.188558                     606.203303
4       128.0   663.320465             901.335935             831.923547                    1158.397573                    1154.890779
5       192.0   882.370476            1223.573478            1098.360487                    1502.839763                    1472.101452
6       256.0  1066.051424            1532.019376            1424.085128                    2021.955711                    1699.906574

After:
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                            
BF16 vs FP8 GEMMs:                                                                                                                          
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.747114               5.612546               5.141551                       7.864173                       7.794574
1         16.0    94.679661              89.500248              82.045488                     128.511559                     127.272434
2         64.0   404.094024             378.562338             346.038245                     553.453917                     553.673110
3        128.0   729.413173             721.308876             650.725180                    1041.533876                    1038.261908
4        256.0  1139.336957             995.640106             942.699714                    1354.846282                    1349.292699
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                            
BF16 vs FP8 GEMMs:                                                                                                                          
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     2.507184               3.951163               3.621669                       5.693663                       5.721870
1         16.0    72.936451              63.934449              58.323370                      91.950043                      91.929950
2         64.0   286.246926             255.536549             232.899315                     373.507265                     373.717844
3        128.0   562.390855             490.499407             442.432846                     708.249839                     705.500147
4        256.0   863.763778             885.687346             801.106454                    1335.404630                    1332.259734
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:                                                           
BF16 vs FP8 GEMMs:                                                                                                                          
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.807663               8.910787               8.764737                      10.621919                      10.419067
1         16.0    83.888922             149.658547             144.921846                     168.898281                     167.030801
2         64.0   327.389390             564.345330             550.199015                     648.955934                     642.618156
3        128.0   628.627237            1056.895933            1034.984411                    1192.150441                    1198.088280
4        256.0  1053.583061            1574.427693            1623.290825                    1629.970175                    1512.843407
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:                                                           
BF16 vs FP8 GEMMs:                                                                                                                          
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     4.953190               6.975380               6.305357                      10.080636                      10.102094
1         16.0    80.948834             110.746501             100.482041                     161.474463                     161.378571
2         64.0   303.193123             433.296150             395.757420                     637.383447                     636.847998
3        128.0   593.497724             831.819854             762.548069                    1234.452258                    1234.169426
4        256.0   902.580858            1406.340148            1350.980816                    2242.739002                    2092.167812

meta-llama/Llama-3.3-70B-Instruct, N=2560 K=8192, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     2.247340               3.659181               3.329538                       5.212307                       5.210786
1         16.0    56.081160              58.202712              52.856770                      83.131943                      83.111734
2         64.0   218.590131             229.122731             208.869113                     328.475612                     328.473047
3        128.0   426.119325             443.693366             401.984038                     634.710338                     632.658806
4        256.0   809.296378             823.096870             741.983692                    1245.943179                    1242.863817
meta-llama/Llama-3.3-70B-Instruct, N=8192 K=2048, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     4.845773               4.290688               3.898556                       6.333271                       6.295150
1         16.0    92.665725              70.383766              64.498301                     106.858177                     105.758622
2         64.0   357.499226             294.016501             267.967038                     467.897992                     467.926530
3        128.0   654.833743             557.441637             493.817842                     861.290769                     858.600960
4        256.0  1055.796509             884.438059             804.766163                    1250.331980                    1231.177675
meta-llama/Llama-3.3-70B-Instruct, N=14336 K=8192, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     6.013571               7.794183               7.684053                       9.408748                       9.511499
1         16.0    95.663708             138.321978             134.952731                     166.878534                     165.761780
2         64.0   368.880370             525.963146             513.749759                     622.059570                     624.920543
3        128.0   716.004597             956.041563             975.715516                    1195.333181                    1187.001804
4        256.0  1113.386202            1378.281148            1418.003490                    1620.575688                    1532.953099
meta-llama/Llama-3.3-70B-Instruct, N=8192 K=7168, BF16 vs FP8 GEMMs TFLOP/s:
BF16 vs FP8 GEMMs:
    batch_size   torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0     5.322447               8.721340               8.076960                      12.377826                      11.952172
1         16.0    87.329907             145.841938             134.309687                     199.273437                     194.992624
2         64.0   327.285311             638.025070             583.029556                     938.374708                     928.551520
3        128.0   683.424825            1090.360496            1093.481578                    1777.431884                    1766.970447
4        256.0  1005.921614            1570.834976            1508.185079                    2042.310073                    1974.284171

Signed-off-by: ilmarkov <imarkov@redhat.com>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @ilmarkov, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request implements performance optimizations for FP8 General Matrix Multiply (GEMM) operations by refining the underlying CUTLASS configurations. The changes focus on improving efficiency for various matrix dimensions, especially for smaller M values (up to 256), based on extensive profiling data. The goal is to enhance throughput and reduce computation time for relevant workloads.

Highlights

  • CUTLASS Configuration Tuning: Optimized CUTLASS GEMM configurations (TileShape and ClusterShape) for FP8 operations on SM100 architecture, specifically targeting M dimensions up to 256. These adjustments are based on insights from cutlass_profiler.
  • Refined M-Dimension Dispatch Logic: The dispatch mechanism for FP8 GEMM kernels has been updated to apply the newly tuned configurations more precisely across different M (batch size) ranges. This includes adjusting the M thresholds and corresponding configuration struct names.
  • Performance Improvements: The changes are intended to yield performance gains, particularly for smaller batch sizes, as demonstrated by the provided benchmark results showing speedups up to 1.54x for certain configurations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The code changes tune CUTLASS configurations for M <= 256 based on cutlass_profiler insights. The changes involve adjusting the tile shape and cluster shape for different ranges of M, and updating the dispatch logic to use the new configurations. The review focuses on ensuring that the comments accurately reflect the new ranges for M, and that the changes align with the performance tuning results.

Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM considering the large improvements for Llama 70B, thanks!

@mgoin mgoin enabled auto-merge (squash) June 25, 2025 21:36
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 25, 2025
@mgoin mgoin changed the title FP8 gemm Optimizations after cutlass_profiler [Perf] SM100 FP8 GEMM Optimizations after cutlass_profiler Jun 26, 2025
@mgoin mgoin added the performance Performance-related issues label Jun 26, 2025
@vllm-bot vllm-bot merged commit 2d7779f into vllm-project:main Jun 27, 2025
109 of 113 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants