
Conversation

@elvischenv (Contributor) commented on Aug 26, 2025

Purpose

Support the FlashInfer TRTLLM FP8-qkv BF16/FP16-out attention kernel.
After this PR, FlashInfer with kv_cache_dtype=fp8 will always quantize the query to FP8 and use the TRTLLM attention kernel (which supports FP8-qkv with BF16/FP16/FP8/NVFP4 output).
Note: this requires FlashInfer 0.3.0 to land.
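
For context, this is roughly what per-tensor FP8 query quantization amounts to in PyTorch. It is a minimal sketch, not the actual vLLM code path; the names and scale handling are assumptions, with q_scale standing in for whatever static query scale the attention layer carries.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quantize_query_fp8(q: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # Illustrative per-tensor quantization: divide by the scale, clamp to the
    # representable e4m3 range, and cast. The TRTLLM kernel then consumes the
    # FP8 query together with the FP8 KV cache.
    return (q / q_scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# Example with a dummy [num_tokens, num_heads, head_dim] query.
q = torch.randn(4, 8, 128, dtype=torch.bfloat16)
q_fp8 = quantize_query_fp8(q, torch.tensor(1.0))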

Test Plan & Test Results

Kernel functional tests:
tests/kernels/attention/test_flashinfer_trtllm_attention.py

===== 112 passed, 8 skipped in 15.66s =====

Kernel performance:
Decode: benchmarks/kernels/benchmark_trtllm_decode_attention.py

Running benchmark for q_dtype = torch.float8_e4m3fn, kv_cache_dtype: torch.float8_e4m3fn, output_dtype: torch.bfloat16
batch_size      max_seq_len     trtllm_mean     trtllm_std      baseline_mean   baseline_std    speedup_percent
        1       1024    0.041   0.003   0.053   0.004   0.222
        4       1024    0.041   0.002   0.051   0.002   0.197
        8       1024    0.042   0.002   0.052   0.003   0.189
        16      1024    0.045   0.003   0.055   0.002   0.190
        32      1024    0.045   0.002   0.060   0.002   0.240
        64      1024    0.051   0.002   0.072   0.003   0.291
        128     1024    0.060   0.005   0.091   0.007   0.339
        256     1024    0.079   0.007   0.129   0.007   0.391
        1       2048    0.041   0.002   0.050   0.002   0.188
        4       2048    0.042   0.002   0.050   0.003   0.160
        8       2048    0.042   0.001   0.051   0.002   0.176
        16      2048    0.045   0.002   0.059   0.002   0.226
        32      2048    0.052   0.002   0.077   0.006   0.328
        64      2048    0.060   0.003   0.094   0.007   0.362
        128     2048    0.079   0.007   0.134   0.007   0.411
        256     2048    0.111   0.007   0.191   0.007   0.419
        1       4096    0.045   0.002   0.052   0.002   0.140
        4       4096    0.043   0.002   0.055   0.002   0.208
        8       4096    0.046   0.002   0.062   0.002   0.258
        16      4096    0.052   0.002   0.072   0.003   0.288
        32      4096    0.065   0.005   0.101   0.007   0.353
        64      4096    0.081   0.006   0.134   0.006   0.401
        128     4096    0.112   0.007   0.203   0.008   0.448
        256     4096    0.181   0.007   0.321   0.007   0.437
        1       8192    0.045   0.002   0.054   0.003   0.164
        4       8192    0.046   0.001   0.059   0.002   0.215
        8       8192    0.054   0.003   0.075   0.003   0.283
        16      8192    0.065   0.007   0.095   0.007   0.323
        32      8192    0.091   0.006   0.150   0.007   0.394
        64      8192    0.120   0.008   0.214   0.008   0.440
        128     8192    0.179   0.006   0.345   0.011   0.482
        256     8192    0.318   0.007   0.589   0.008   0.460
        1       16384   0.047   0.003   0.057   0.002   0.177
        4       16384   0.054   0.003   0.078   0.007   0.314
        8       16384   0.065   0.006   0.091   0.007   0.285
        16      16384   0.088   0.007   0.117   0.007   0.250
        32      16384   0.135   0.007   0.243   0.007   0.443
        64      16384   0.206   0.008   0.409   0.011   0.496
        128     16384   0.338   0.009   0.688   0.013   0.509
        256     16384   0.579   0.008   1.082   0.013   0.465
        1       32768   0.049   0.002   0.069   0.002   0.285
        4       32768   0.066   0.007   0.097   0.007   0.317
        8       32768   0.092   0.008   0.134   0.007   0.312
        16      32768   0.143   0.008   0.218   0.007   0.345
        32      32768   0.253   0.007   0.497   0.008   0.492
        64      32768   0.396   0.009   0.804   0.022   0.507
        128     32768   0.662   0.012   1.349   0.022   0.509
        256     32768   1.077   0.012   2.055   0.019   0.476
        1       65536   0.057   0.003   0.090   0.016   0.368
        4       65536   0.098   0.008   0.159   0.008   0.385
        8       65536   0.144   0.007   0.241   0.008   0.401
        16      65536   0.248   0.007   0.409   0.008   0.395
        32      65536   0.460   0.008   0.921   0.008   0.500
        64      65536   0.731   0.012   1.462   0.024   0.500
        128     65536   1.246   0.018   2.620   0.037   0.525
        256     65536   2.308   0.024   4.373   0.031   0.472
        1       131072  0.073   0.006   0.121   0.007   0.398
        4       131072  0.156   0.007   0.261   0.008   0.404
        8       131072  0.255   0.007   0.379   0.007   0.327
        16      131072  0.455   0.007   0.601   0.008   0.242
        32      131072  0.837   0.032   1.671   0.009   0.499
        64      131072  1.362   0.054   2.897   0.050   0.530
        128     131072  2.398   0.032   5.014   0.044   0.522
        256     131072  4.466   0.037   8.737   0.077   0.489
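
For reference, speedup_percent in these tables is the relative latency reduction of the TRTLLM kernel versus the baseline. A quick check against one of the rows above (the formula is inferred from the numbers, so treat it as an assumption about the benchmark script):

def speedup_percent(trtllm_mean: float, baseline_mean: float) -> float:
    # Fraction of the baseline latency saved by the TRTLLM kernel.
    return (baseline_mean - trtllm_mean) / baseline_mean

# Row batch_size=256, max_seq_len=131072 from the decode table above.
print(round(speedup_percent(4.466, 8.737), 3))  # 0.489, matching the table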

Prefill: benchmarks/kernels/benchmark_trtllm_prefill_attention.py

Running benchmark for q_dtype = torch.float8_e4m3fn, kv_cache_dtype: torch.float8_e4m3fn, output_dtype: torch.bfloat16
batch_size      max_seq_len     trtllm_mean     trtllm_std      baseline_mean   baseline_std    speedup_percent
        1       2048       0.179           0.006           0.243           0.008           0.263
        4       2048       0.333           0.007           0.399           0.010           0.166
        8       2048       0.603           0.007           0.660           0.011           0.086
        16      2048       0.834           0.007           0.887           0.011           0.059
        32      2048       1.655           0.007           1.706           0.011           0.030
        64      2048       3.222           0.007           3.278           0.012           0.017
        128     2048       5.900           0.008           5.880           0.010          -0.003
        256     2048      11.425           0.023          11.379           0.012          -0.004
        1       4096       0.525           0.006           0.710           0.012           0.260
        4       4096       0.841           0.006           1.091           0.012           0.229
        8       4096       1.738           0.007           2.122           0.017           0.181
        16      4096       3.976           0.007           4.871           0.029           0.184
        32      4096       6.341           0.018           7.526           0.017           0.158
        64      4096      12.286           0.008          14.695           0.032           0.164
        128     4096      19.099           0.013          22.307           0.021           0.144
        256     4096      38.410           0.021          44.985           0.028           0.146
        1       8192       1.751           0.007           2.444           0.012           0.284
        4       8192       4.346           0.007           5.888           0.026           0.262
        8       8192       4.999           0.007           6.564           0.016           0.238
        16      8192       7.491           0.006           9.675           0.030           0.226
        32      8192      16.533           0.008          21.231           0.010           0.221
        64      8192      31.066           0.021          39.888           0.033           0.221
        128     8192      65.677           0.029          84.116           0.020           0.219
        256     8192     131.214           0.012         168.143           0.039           0.220
        1       16384      6.420           0.006           9.001           0.019           0.287
        4       16384     14.249           0.008          19.525           0.015           0.270
        8       16384     20.049           0.008          27.262           0.017           0.265
        16      16384     26.334           0.006          35.547           0.014           0.259
        32      16384     63.440           0.012          86.571           0.015           0.267
        64      16384    116.734           0.026         158.061           0.042           0.261
        128     16384    230.956           0.043         312.395           0.040           0.261
        256     16384    498.184           0.043         676.383           0.044           0.263
        1       32768     24.602           0.009          34.960           0.011           0.296
        4       32768     37.232           0.009          52.619           0.018           0.292
        8       32768     56.907           0.014          80.481           0.022           0.293
        16      32768     75.017           0.024         104.545           0.013           0.282
        32      32768    178.917           0.020         249.813           0.029           0.284
        64      32768    412.455           0.030         575.280           0.014           0.283
...

kv_cache_dtype=fp8 + unfused/fused unit test:
tests/compile/test_fusion_attn.py::test_attention_quant_pattern

======= 12 passed, 4 warnings in 30.10s ======

lm_eval
On main:
nvidia/Llama-3.3-70B-Instruct-FP8

kv_cache_dtype=auto:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'auto', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.900|±  |0.0134|

kv_cache_dtype=fp8:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.942|±  |0.0105|
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|

kv_cache_dtype=fp8 + enable_attn_fusion:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'compilation_config': {'cudagraph_mode': 'FULL_DECODE_ONLY', 'splitting_ops': [], 'custom_ops': ['+quant_fp8'], 'pass_config': {'enable_attn_fusion': True, 'enable_noop': True}}, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.93|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  | 0.89|±  |0.0140|

nvidia/Llama-3.3-70B-Instruct-FP4

kv_cache_dtype=auto:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'auto', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.940|±  |0.0106|
|     |       |strict-match    |     5|exact_match|↑  |0.852|±  |0.0159|

kv_cache_dtype=fp8:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.93|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  | 0.86|±  |0.0155|

kv_cache_dtype=fp8 + enable_attn_fusion:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'compilation_config': {'cudagraph_mode': 'FULL_DECODE_ONLY', 'splitting_ops': [], 'custom_ops': ['+quant_fp8'], 'pass_config': {'enable_attn_fusion': True, 'enable_noop': True}}, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.934|±  |0.0111|
|     |       |strict-match    |     5|exact_match|↑  |0.820|±  |0.0172|

With this PR:
nvidia/Llama-3.3-70B-Instruct-FP8

kv_cache_dtype=auto:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'auto', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|

kv_cache_dtype=fp8:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|

kv_cache_dtype=fp8 + enable_attn_fusion:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'compilation_config': {'cudagraph_mode': 'FULL_DECODE_ONLY', 'splitting_ops': [], 'custom_ops': ['+quant_fp8'], 'pass_config': {'enable_attn_fusion': True, 'enable_noop': True}}, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.888|±  |0.0141|

nvidia/Llama-3.3-70B-Instruct-FP4

kv_cache_dtype=auto:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'auto', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.936|±  | 0.011|
|     |       |strict-match    |     5|exact_match|↑  |0.850|±  | 0.016|

kv_cache_dtype=fp8:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.862|±  |0.0154|

kv_cache_dtype=fp8 + enable_attn_fusion:
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP4', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'compilation_config': {'cudagraph_mode': 'FULL_DECODE_ONLY', 'splitting_ops': [], 'custom_ops': ['+quant_fp8'], 'pass_config': {'enable_attn_fusion': True, 'enable_noop': True}}, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.936|±  | 0.011|
|     |       |strict-match    |     5|exact_match|↑  |0.824|±  | 0.017|
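
For reproducibility, the enable_attn_fusion runs above pass their settings through the lm_eval model_args dict. A rough offline equivalent using vLLM's Python API might look like the sketch below; the argument plumbing is assumed from the dict above and may differ slightly across vLLM versions.

from vllm import LLM, SamplingParams

llm = LLM(
    model="nvidia/Llama-3.3-70B-Instruct-FP8",
    kv_cache_dtype="fp8",
    tensor_parallel_size=1,
    max_model_len=2048,
    trust_remote_code=True,
    # Same compilation_config dict as in the lm_eval runs above.
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",
        "splitting_ops": [],
        "custom_ops": ["+quant_fp8"],
        "pass_config": {"enable_attn_fusion": True, "enable_noop": True},
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)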


@houseroad (Collaborator):

Btw, shall we try gpt-oss on GB200? AIME + high reasoning effort is quite useful for checking accuracy.

@elvischenv force-pushed the elvischenv/trtllm-fp8-q-bf16-out-attn branch from 2e0b199 to 8562b40 on September 5, 2025 02:24
@elvischenv marked this pull request as ready for review on September 5, 2025 02:25
@elvischenv force-pushed the elvischenv/trtllm-fp8-q-bf16-out-attn branch from 8562b40 to 5c5647b on September 5, 2025 04:05
@ProExpertProg (Collaborator) left a comment


Looks good, let's wait for the FlashInfer version bump to land.

@elvischenv (Contributor, Author):

@ProExpertProg Thanks for the review. #24086 has updated FlashInfer to 0.3.0 on main.

@ProExpertProg enabled auto-merge (squash) on September 5, 2025 15:05
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Sep 5, 2025
auto-merge was automatically disabled September 7, 2025 15:55

Head branch was pushed to by a user without write access

@elvischenv force-pushed the elvischenv/trtllm-fp8-q-bf16-out-attn branch from 4deafc8 to d203048 on September 7, 2025 15:55
@elvischenv (Contributor, Author):

Still hitting the tests/kernels/test_cutlass_mla_decode.py failure in the Blackwell CI stage. This PR should not affect that test since the attention backend is completely different.

All the tests pass locally when I build the full wheel from source using the Incremental Compilation Workflow. If I install vLLM from source with VLLM_USE_PRECOMPILED=1, I can reproduce the same failure that CI hits.
https://buildkite.com/vllm/ci/builds/29725/steps/canvas?sid=019924e5-1b54-48cd-abab-1f6afb4668cc#019924e5-1c9e-428d-ba51-648d0c091c42/102-1113

>           assert cos_diff < 1e-4
E           assert 1.0 < 0.0001
tests/kernels/test_cutlass_mla_decode.py:22: AssertionError
======================================================================= warnings summary

@ProExpertProg Do you have any suggestions for resolving this? It is odd that other PRs pass these tests on CI. Thanks!
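
For anyone reading along, the failing assertion is a standard numerical-closeness check. A generic version of that kind of guard (not the actual test body, just its usual shape) looks like this:

import torch

def cos_diff(out: torch.Tensor, ref: torch.Tensor) -> float:
    # 1 - cosine similarity over the flattened tensors; 0 means identical direction.
    cos = torch.nn.functional.cosine_similarity(
        out.flatten().float(), ref.flatten().float(), dim=0)
    return float(1.0 - cos)

# The test then asserts the kernel output stays close to the reference:
# assert cos_diff(kernel_out, reference_out) < 1e-4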

@ProExpertProg (Collaborator):

What happens if you build with uv pip install but with VLLM_USE_PRECOMPILED=0?

@elvischenv (Contributor, Author) commented on Sep 8, 2025

What happens if you build with uv pip install but with VLLM_USE_PRECOMPILED=0?

All the tests pass with this PR:

====== 48 passed, 1 warning in 23.99s ======

@vllm-bot merged commit bba1042 into vllm-project:main on Sep 9, 2025 (38 of 40 checks passed)
@elvischenv deleted the elvischenv/trtllm-fp8-q-bf16-out-attn branch on September 9, 2025 05:07
@ProExpertProg (Collaborator) commented on Sep 9, 2025

EDIT: I had to upgrade FlashInfer.

This PR broke tests/compile/test_fusion_attn.py for me locally:

FAILED tests/compile/test_fusion_attn.py::test_attention_quant_pattern[_Backend.FLASHINFER-nvidia/Llama-4-Scout-17B-16E-Instruct-FP8-TestAttentionFp8StaticQuantPatternModel-dtype0-7-128-64-8] - ValueError: Invalid dtype of out: expected torch.float8_e4m3fn, got torch.bfloat16

@gau-nernst (Contributor):

@elvischenv On devices where FlashInfer uses the FA2 backend (e.g. SM120), an FP8 query is not supported.

https://github.com/flashinfer-ai/flashinfer/blob/f7defbf9905b0d897c59afe84763fb6752ca8c22/flashinfer/jit/attention/pytorch.py#L519

Defaulting to an FP8 query prevents me from enabling the FP8 KV cache on SM120. I see two possible solutions for this:

  1. Introduce a new flag controlling whether to use an FP8 query. I feel using an FP8 query should be explicit, because it may introduce accuracy degradation and diverges from the other attention backends.
  2. Add more checks to determine whether to use an FP8 query by default (when the KV-cache dtype is FP8). Maybe it's just a matter of enabling the FP8 query only when prefill_use_trtllm = True? (See the sketch after this comment.)

Anyway, I'm more than happy to work with you on either solution. I have confirmed locally on my machine that simply setting q_data_type = torch.bfloat16 and allowing the FP8 KV cache for FlashInfer attention in vllm.platforms.cuda works perfectly.
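
A hypothetical sketch of the option-2 gating described above; the helper name and the use_trtllm_attention flag are assumptions about where such a check would live, not existing vLLM code.

import torch

def choose_query_dtype(kv_cache_dtype: str,
                       use_trtllm_attention: bool,
                       model_dtype: torch.dtype) -> torch.dtype:
    # Only quantize the query to FP8 when the TRTLLM kernels will consume it;
    # otherwise (e.g. the FA2 backend on SM120) keep the query in the model dtype.
    if kv_cache_dtype.startswith("fp8") and use_trtllm_attention:
        return torch.float8_e4m3fn
    return model_dtype  # e.g. torch.bfloat16, as suggested above

q_dtype = choose_query_dtype("fp8", use_trtllm_attention=False,
                             model_dtype=torch.bfloat16)
assert q_dtype is torch.bfloat16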
