Conversation

@elvischenv
Contributor

@elvischenv commented Aug 1, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as the test command(s) used.
  • The test results, such as a before/after results comparison or e2e results.
  • (Optional) Any necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

#19825 previously added the TRTLLM attention kernel for the decode path; this PR adds support for the prefill path.

  • Unified the decode and prefill code paths; a single env var, VLLM_USE_TRTLLM_ATTENTION, now controls both (see the sketch after this list).
  • Currently the only config the TRTLLM prefill kernel does not support is Q-BF16 / KV-FP8 / O-BF16; once the attn+quant fusion is supported, that case will go directly through Q-FP8 / KV-FP8 with O-FP8/O-NVFP4.
  • Overall, the TRTLLM prefill kernels perform much better than the original kernels (see the benchmarks below).
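
For context, the env-var gating works roughly like the sketch below. This is a minimal illustration, not vLLM's actual code: the function signature, the auto-detection stub, and the overall structure are assumptions; only the env var name VLLM_USE_TRTLLM_ATTENTION and the prefill guards come from this PR.

import os

def auto_detect_trtllm() -> bool:
    # Placeholder for the real auto-detection (GPU arch, kernel availability).
    return True

def use_trtllm_attention(is_prefill: bool, kv_cache_dtype: str,
                         window_left: int) -> bool:
    # A single env var, VLLM_USE_TRTLLM_ATTENTION, controls both paths.
    flag = os.environ.get("VLLM_USE_TRTLLM_ATTENTION")
    use_trtllm = (flag == "1") if flag is not None else auto_detect_trtllm()
    if use_trtllm and is_prefill:
        # Prefill currently lacks a BF16-Q/FP8-KV kernel and may not support
        # sliding windows, so these guards apply even under the env override
        # (see the review discussion further down).
        return window_left == -1 and not kv_cache_dtype.startswith("fp8")
    return use_trtllm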

Test Plan && Test Result

tests/kernels/attention/test_flashinfer_trtllm_attention.py

==== 160 passed, 32 skipped, 129 warnings in 26.46s ====
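
(The summary above comes from running the listed file with pytest, e.g. pytest tests/kernels/attention/test_flashinfer_trtllm_attention.py; the exact flags and GPU are not shown in the log, so the invocation is an assumption.)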

lm_eval

vllm ({'pretrained': '/home/scratch.omniml_data_2/HF_model_hub/Llama-4-Scout-17B-16E-Instruct-FP8', 'quantization': 'modelopt', 'kv_cache_dtype': 'auto', 'tensor_parallel_size': 1, 'compilation_config': {'level': 3, 'custom_ops': ['+rms_norm'], 'pass_config': {'enable_fi_allreduce_fusion': True}, 'full_cuda_graph': True}, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.936|±  |0.0110|
|     |       |strict-match    |     5|exact_match|↑  |0.914|±  |0.0126|
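
The run above can be reproduced programmatically; the sketch below reconstructs it from the config dump via lm_eval's simple_evaluate API. The exact invocation used for this PR is not shown, so treat the call shape (including passing model_args as a dict) as an assumption.

import lm_eval

# Hedged reconstruction of the GSM8K run from the config dump above.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": "/home/scratch.omniml_data_2/HF_model_hub/Llama-4-Scout-17B-16E-Instruct-FP8",
        "quantization": "modelopt",
        "kv_cache_dtype": "auto",
        "tensor_parallel_size": 1,
        "compilation_config": {
            "level": 3,
            "custom_ops": ["+rms_norm"],
            "pass_config": {"enable_fi_allreduce_fusion": True},
            "full_cuda_graph": True,
        },
        "max_model_len": 2048,
        "trust_remote_code": True,
    },
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size=200,
    limit=500,
    gen_kwargs="temperature=0.0",
)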

benchmarks/kernels/benchmark_trtllm_prefill_attention.py

Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, output_dtype: bfloat16
num_seqs  max_seq_len   trt_mean  trt_std baseline_mean  baseline_std   speedup_percent
1         2048          0.102     0.004   0.233          0.010          0.562
4         2048          0.204     0.005   0.377          0.011          0.459
8         2048          0.284     0.004   0.457          0.012          0.379
16        2048          0.660     0.004   0.995          0.012          0.337
32        2048          1.009     0.006   1.446          0.010          0.302
64        2048          2.104     0.007   2.920          0.013          0.280
128       2048          4.416     0.010   6.033          0.011          0.268
256       2048          8.186     0.012   10.907         0.010          0.249
1         4096          0.259     0.005   0.699          0.013          0.629
4         4096          0.521     0.007   1.175          0.013          0.556
8         4096          0.654     0.006   1.397          0.013          0.531
16        4096          1.698     0.009   3.648          0.013          0.534
32        4096          2.840     0.007   5.510          0.011          0.485
64        4096          5.057     0.041   9.810          0.016          0.484
128       4096          10.667    0.053   19.737         0.014          0.460
256       4096          23.228    0.425   42.943         0.024          0.459
1         8192          0.784     0.006   2.420          0.010          0.676
4         8192          1.093     0.005   3.034          0.019          0.640
8         8192          4.028     0.118   10.455         0.014          0.615
16        8192          4.767     0.073   11.708         0.015          0.593
32        8192          7.630     0.066   17.438         0.012          0.562
64        8192          13.507    0.062   30.210         0.014          0.553
128       8192          34.001    0.397   76.420         0.043          0.555
256       8192          69.327    0.358   155.595        0.047          0.554
1         16384         2.730     0.012   8.955          0.020          0.695
4         16384         5.805     0.115   16.405         0.020          0.646
8         16384         8.579     0.309   23.654         0.015          0.637
16        16384         12.934    0.261   33.790         0.023          0.617
32        16384         32.349    0.381   84.966         0.040          0.619
64        16384         58.197    0.268   150.340        0.029          0.613
128       16384         107.968   0.840   277.745        0.042          0.611
256       16384         259.083   0.478   677.202        0.055          0.617
1         32768         11.562    0.097   34.844         0.016          0.668
4         32768         18.377    0.344   53.366         0.014          0.656
8         32768         37.628    0.384   106.451        0.020          0.647
16        32768         49.803    0.274   138.859        0.027          0.641
32        32768         129.513   0.203   365.814        0.041          0.646
64        32768         251.496   0.676   709.115        0.025          0.645
128       32768         445.987   0.357   1253.564       0.024          0.644
256       32768         823.417   0.697   2299.850       0.053          0.642
1         65536         45.538    0.384   137.810        0.026          0.670
4         65536         58.513    0.373   174.282        0.019          0.664
8         65536         120.448   0.207   355.113        0.027          0.661
16        65536         216.930   0.376   640.438        0.026          0.661
32        65536         435.518   0.462   1280.544       0.025          0.660
64        65536         896.008   0.783   2625.461       0.047          0.659
128       65536         1748.643  1.167   5130.747       0.064          0.659
256       65536         OOM
1         131072        177.475   0.701   548.542        0.061          0.676
4         131072        214.061   0.558   657.873        0.051          0.675
8         131072        692.829   1.235   2120.498       0.076          0.673
16        131072        619.111   0.601   1872.617       0.060          0.669
32        131072        1526.787  1.126   4600.271       0.095          0.668
64        131072        3330.554  2.090   10066.052      0.125          0.669
...OOM
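
A note on reading the table: speedup_percent is the fractional latency reduction, 1 - trt_mean / baseline_mean (e.g. 1 - 0.102/0.233 ≈ 0.562 in the first row), so higher is better. The metric name comes from the benchmark output; the formula is inferred from the numbers. A quick spot-check:

# Spot-check speedup_percent = 1 - trt_mean / baseline_mean against rows above.
rows = [(0.102, 0.233, 0.562), (0.204, 0.377, 0.459), (0.784, 2.420, 0.676)]
for trt_mean, baseline_mean, reported in rows:
    assert abs((1 - trt_mean / baseline_mean) - reported) < 1e-3
print("speedup_percent matches 1 - trt_mean/baseline_mean")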

(Optional) Documentation Update

@github-actions

github-actions bot commented Aug 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; they only run the fastcheck CI, which covers a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of that by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

mergify bot added the performance and v1 labels Aug 1, 2025
@gemini-code-assist
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@mergify

mergify bot commented Aug 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @elvischenv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Aug 2, 2025
elvischenv force-pushed the elvischenv/flashinfer-prefill-trtllm-attn branch from 2a87f39 to 21aede2 on August 4, 2025 05:40
mergify bot removed the needs-rebase label Aug 4, 2025
@nvpohanh
Contributor

nvpohanh commented Aug 4, 2025

Overall looks good to me. Thanks!

@pavanimajety Please also help to review this.

elvischenv force-pushed the elvischenv/flashinfer-prefill-trtllm-attn branch from 21aede2 to 7237472 on August 4, 2025 12:51
Collaborator

Is this still true? Please update the comment if more head group sizes are supported and change the logic for the head group ratio in use_trtllm_attention

Contributor Author

Thanks. Fixed in the latest commit.

Collaborator

Have the cubins been updated to support both layouts? In that case, we may want to remove the default HND restriction placed for SM100

Contributor Author

I think FlashInfer still has this constraint; its unit tests still cover HND only.

Collaborator

We may have to return False from use_trtllm_attention when window_left has a non-default value as well.

Contributor Author

Fixed in the latest commit.

Collaborator

@pavanimajety left a comment

Thank you for the PR, @elvischenv. Left some minor feedback comments.

@pavanimajety
Collaborator

@elvischenv Does this PR have full CUDA graph support?

@pavanimajety
Collaborator

@mgoin Can you please review too? Thanks!

Member

Please update Blackwell Test in .buildkite/test-pipeline.yaml to include this

Contributor Author

Thanks. Fixed in the latest commit.

Member

@mgoin left a comment

Awesome work, this looks good to me. Will try to smoke test when I get access to B200 again

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv force-pushed the elvischenv/flashinfer-prefill-trtllm-attn branch from 7237472 to 1918711 on August 5, 2025 03:49
mergify bot added the ci/build label Aug 5, 2025
@elvischenv
Contributor Author

@elvischenv Does this PR have full cuda graph support?

We have tested with full CUDA graph and it seems to work.

# currently prefill trtllm attention does not support fp8 kv cache
# trtllm may not support sliding window
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
                      and not cache_dtype.startswith("fp8"))
Member

Nit: I believe this is already checked in use_trtllm_attention

Contributor Author

Do you mean not cache_dtype.startswith("fp8")? use_trtllm_attention can be overridden by VLLM_USE_TRTLLM_ATTENTION=1, and even with VLLM_USE_TRTLLM_ATTENTION=1 we still cannot use a TRTLLM FP8 KV cache, since there is no BF16-Q/FP8-KV kernel.

This is just a workaround (WAR) for now; it will be updated after the FP8-Q/FP8-KV kernel is supported.

Contributor

I would prefer that we clean these up after we have the Attn+FP8/FP4-Quant fusions. Things will be clearer when that part is done. Thanks!

mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) Aug 5, 2025
vllm-bot merged commit 83156c7 into vllm-project:main Aug 5, 2025
81 of 84 checks passed
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
elvischenv deleted the elvischenv/flashinfer-prefill-trtllm-attn branch August 7, 2025 00:12
myselvess pushed a commit to myselvess/vllm that referenced this pull request Aug 7, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Noam Gat <noamgat@gmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…oject#22095)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@frank-wei
Contributor

frank-wei commented Sep 2, 2025

@elvischenv @nvpohanh @pavanimajety do you have any update on support for the FP8-Q/FP8-KV kernel?

@nvpohanh
Contributor

nvpohanh commented Sep 3, 2025

@frank-wei FP8-QKV and FP8/FP4-output are already supported. FP8-QKV and BF16/FP16-output will be supported after #23647 is merged
