
Conversation

noooop
Contributor

@noooop noooop commented Aug 19, 2025

Purpose

V1 Pooling Models E2E Performance Optimization

#  https://github.com/noooop/snippet/blob/main/benchmarks/embed/profile.py
VLLM_TORCH_PROFILER_DIR=/xxx python profile.py

main: CPU: 67.134ms CUDA: 46.417ms -> this PR: CPU: 39.740ms CUDA: 39.890ms
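For reference, per-step CPU/CUDA time of the kind quoted above can be collected with `torch.profiler`; the linked `profile.py` is the actual harness, so the step function and timing extraction below are only an illustrative sketch:

```python
# Hedged sketch: measuring per-step CPU time with torch.profiler.
# `one_step` stands in for a model forward + pooling step; the real
# numbers in this PR come from the linked profile.py harness.
import torch
from torch.profiler import ProfilerActivity, profile

def one_step(x: torch.Tensor) -> torch.Tensor:
    # stand-in workload for a single engine step
    return (x @ x.T).relu().mean(dim=1)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    one_step(torch.randn(64, 128))

# Total self-CPU time across recorded events, in milliseconds
total_cpu_ms = sum(e.self_cpu_time_total for e in prof.events()) / 1000.0
print(f"CPU: {total_cpu_ms:.3f}ms")
```

Setting `VLLM_TORCH_PROFILER_DIR` as in the command above makes vLLM write traces for offline inspection; the snippet only shows the underlying profiler API.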

benchmarks:

Result:

  • X-axis: Throughput (token/s)
  • Y-axis: Latency, Time needed for one step (ms)
  • Lower and further right on the curve is better ↘

Frankly speaking, these optimizations only show significant improvements when handling many small requests; they have almost no impact on large requests, since the backbone network latency dominates in such cases.

Of course, a faster implementation is always cooler.

Detailed per-commit breakdown below:

  1. replace_roberta_positions takes too long

6d8f55e: CPU: 67.134ms CUDA: 46.417ms -> CPU: 53.261ms CUDA: 42.820ms

Correctness testing

  • pytest -s -vvv tests/models/language/pooling/test_baai.py
    • BAAI/bge-m3 for XLMRobertaModel
    • BAAI/bge-reranker-base for XLMRobertaForSequenceClassification
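The cost in this step comes from building RoBERTa-style position ids, which start at `padding_idx + 1` and restart for every sequence. A minimal sketch of a fully batched variant over a flattened token stream; the function name and `PADDING_IDX` are illustrative, not vLLM's actual `replace_roberta_positions`:

```python
# Hedged sketch: RoBERTa position ids start at padding_idx + 1 and
# restart per sequence. Names here are illustrative only.
import torch

PADDING_IDX = 1  # RoBERTa reserves low ids; real positions start at 2

def roberta_positions_batched(seq_lens: torch.Tensor) -> torch.Tensor:
    # One batched pass over the flattened token stream instead of a
    # per-request Python loop.
    starts = torch.repeat_interleave(
        torch.cumsum(seq_lens, dim=0) - seq_lens, seq_lens
    )
    flat = torch.arange(int(seq_lens.sum()))
    return flat - starts + PADDING_IDX + 1

print(roberta_positions_batched(torch.tensor([3, 2])))
# tensor([2, 3, 4, 2, 3]) -- positions restart per sequence
```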
  2. reduce CUDA sync

8ff2418: CPU: 53.261ms CUDA: 42.820ms -> CPU: 43.016ms CUDA: 42.516ms
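A CUDA sync happens whenever host code needs a value that lives on the GPU, e.g. `.item()` or a Python comparison on a device tensor. A small illustrative sketch (not vLLM code) of the pattern being removed:

```python
# Hedged sketch of a hidden CUDA sync: any host-side read of a device
# tensor (.item(), int(), Python comparisons) blocks the CPU until the
# GPU stream catches up.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens_gpu = torch.tensor([5, 7, 3], device=device)

# Syncs once per element: each .item() is a device-to-host read.
total_with_sync = sum(int(x.item()) for x in seq_lens_gpu)

# No sync on the hot path: keep a CPU-side copy and read from that.
seq_lens_cpu = torch.tensor([5, 7, 3])
total_without_sync = int(seq_lens_cpu.sum())

assert total_with_sync == total_without_sync == 15
```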

  3. Use as many batch operations as possible

13e44df: CPU: 43.016ms CUDA: 42.516ms -> CPU: 42.463ms CUDA: 40.629ms
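"Batch operations" here means replacing a Python loop of small kernels with one vectorized op. An illustrative sketch of the idea using a segment mean (names and shapes are made up, not the PR's code):

```python
# Hedged sketch: loop-of-small-kernels vs. a single batched kernel.
import torch

hidden = torch.randn(10, 4)  # flattened token hidden states
lens = [3, 3, 4]             # tokens per request

# Loop version: one small kernel launch per request.
means_loop = torch.stack(
    [chunk.mean(dim=0) for chunk in torch.split(hidden, lens)]
)

# Batched version: a single index_add_ computes all segment sums.
seg = torch.repeat_interleave(torch.arange(len(lens)), torch.tensor(lens))
sums = torch.zeros(len(lens), hidden.shape[1]).index_add_(0, seg, hidden)
means_batched = sums / torch.tensor(lens, dtype=hidden.dtype).unsqueeze(1)

assert torch.allclose(means_loop, means_batched)
```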

  4. non_blocking seq_lens (use seq_lens_cpu directly later)

876cb9a: CPU: 42.463ms CUDA: 40.629ms -> CPU: 39.934ms CUDA: 40.128ms
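A sketch of the `non_blocking` pattern, under the assumption that a pinned CPU copy is kept around for host-side logic (variable names are illustrative):

```python
# Hedged sketch: pinned host memory plus non_blocking=True lets the
# H2D copy overlap with GPU compute, and later host-side logic reads
# the CPU copy instead of syncing back from the device.
import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

seq_lens_cpu = torch.tensor([5, 7, 3], pin_memory=use_cuda)
seq_lens = seq_lens_cpu.to(device, non_blocking=True)

# Host logic keeps using seq_lens_cpu directly: no device round-trip.
max_len = int(seq_lens_cpu.max())
assert max_len == 7
```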

  5. remove the prompt_len == hidden_states.shape[0] CUDA sync by comparing prompt_len (CPU tensor) == num_scheduled_tokens (CPU tensor)

93d0a95: CPU: 39.934ms CUDA: 40.128ms -> CPU: 39.509ms CUDA: 40.094ms
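The trick in this step, sketched with illustrative names: comparing a device scalar against a Python value forces a device-to-host read, while comparing two CPU tensors does not.

```python
# Hedged sketch (not vLLM code): avoid syncing by comparing two CPU
# tensors instead of a GPU scalar against a Python int.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(15, 8, device=device)

# Before: prompt_len lives on the GPU, so this comparison would sync.
prompt_len_gpu = torch.tensor(15, device=device)
# finished = bool(prompt_len_gpu == hidden_states.shape[0])  # D2H sync

# After: both operands are already CPU tensors, so no sync occurs.
prompt_len = torch.tensor(15)            # CPU tensor
num_scheduled_tokens = torch.tensor(15)  # CPU tensor
finished = bool(prompt_len == num_scheduled_tokens)

assert finished
```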

  6. _decode_token_type_ids uses ones_like to avoid reading the shape, which causes a CUDA sync

f649899: CPU: 39.509ms CUDA: 40.094ms -> CPU: 39.317ms CUDA: 40.109ms

  7. non_blocking torch.zeros in _build_encoder_only_attn_metadata

d7fa9a8: CPU: 39.317ms CUDA: 40.109ms -> CPU: 39.182ms CUDA: 39.981ms

  8. use pooling_cursor rather than torch.split(hidden_states, num_scheduled_tokens_list)

2f10806: CPU: 39.182ms CUDA: 39.981ms -> CPU: 39.740ms CUDA: 39.890ms
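`torch.split` materializes a Python list of per-request views that the pooler then loops over; a cursor of offsets can select the same rows in one indexing op. A sketch under illustrative names, assuming last-token pooling for brevity:

```python
# Hedged sketch: per-request split-and-loop vs. one gather via a
# cursor of end offsets. Names are illustrative, not vLLM's.
import torch

hidden_states = torch.arange(20.0).reshape(10, 2)  # 10 tokens, dim 2
num_scheduled_tokens = torch.tensor([3, 3, 4])     # tokens per request

# Split version: Python list of views, then a per-request loop.
chunks = torch.split(hidden_states, num_scheduled_tokens.tolist())
last_split = torch.stack([c[-1] for c in chunks])

# Cursor version: end offsets select all last-token rows at once.
cursor = torch.cumsum(num_scheduled_tokens, dim=0) - 1
last_cursor = hidden_states[cursor]

assert torch.equal(last_split, last_cursor)
```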

Test Plan

Keep CI green.

Test Result

(Optional) Documentation Update



Signed-off-by: wang.yuqi <noooop@126.com>

@mergify mergify bot added the v1 label Aug 19, 2025
noooop added 2 commits August 19, 2025 17:23
@noooop noooop changed the title [Model] Pooling Models E2E Performance Optimization [Performance] Pooling Models E2E Performance Optimization Aug 20, 2025
noooop added 10 commits August 20, 2025 14:51
@noooop noooop marked this pull request as ready for review August 20, 2025 13:46
dosubot bot commented Aug 20, 2025

Related Documentation

No published documentation to review for changes on this repository.

@maxdebayser
Contributor

@noooop , nice work! There are two optimizations in this PR: the roberta position ids and removing the split of the hidden states. Do you have benchmark results on how much each individually changes the performance compared to the main branch?

@noooop
Contributor Author

noooop commented Aug 20, 2025

> @noooop , nice work! There are two optimizations in this PR: the roberta position ids and removing the split of the hidden states. Do you have benchmark results on how much each individually changes the performance compared to the main branch?

See the detailed per-commit breakdown in the PR description above.

@noooop
Contributor Author

noooop commented Aug 20, 2025

these optimizations only show significant improvements when handling many small requests—they have almost no impact on large requests (since the backbone network latency dominates in such cases).

@noooop noooop changed the title [Performance] Pooling Models E2E Performance Optimization [Performance] V1 Pooling Models E2E Performance Optimization Aug 21, 2025
@noooop
Contributor Author

noooop commented Aug 21, 2025

@DarkLight1337

By the way, since #22878 the pooling models MTEB test uses enforce_eager, and mteb_test_embed_models has become less flaky (so far, never).

I'll reset the threshold back to MTEB_EMBED_TOL = 1e-4 next month.

I'm not 100% sure the flakiness is caused by torch.compile, but it very likely is.

cc @robertgshaw2-redhat

Member

@DarkLight1337 DarkLight1337 left a comment


Tests pass so this should be good to go, thanks!

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) August 21, 2025 11:30
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 21, 2025
@DarkLight1337 DarkLight1337 merged commit d70a166 into vllm-project:main Aug 21, 2025
60 checks passed
@noooop noooop deleted the pooling_e2e branch August 21, 2025 13:43
PapaGoose pushed a commit to PapaGoose/vllm that referenced this pull request Aug 21, 2025
djmmoss pushed a commit to djmmoss/vllm that referenced this pull request Aug 21, 2025
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 21, 2025
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 22, 2025
Xu-Wenqing pushed a commit to Xu-Wenqing/vllm that referenced this pull request Aug 23, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
juuice-lee pushed a commit to juuice-lee/vllm-moe.code that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
dumb0002 pushed a commit to dumb0002/vllm that referenced this pull request Aug 28, 2025
2015aroras pushed a commit to 2015aroras/vllm that referenced this pull request Aug 29, 2025
mengxingkongzhouhan pushed a commit to mengxingkongzhouhan/vllm that referenced this pull request Aug 30, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
langc23 pushed a commit to zte-riscv/vllm that referenced this pull request Sep 23, 2025