
Conversation


@noooop noooop commented Aug 28, 2025

Purpose

"head" refers to the last Linear layer(s) of an LLM, such as the lm_head in a generation model, or the score or classifier in a classification model.

An increasing amount of evidence suggests that using an fp32 head can improve numerical precision.

From https://www.arxiv.org/pdf/2506.13585#page=7.62
Through layer-by-layer analysis, we identified high-magnitude activations in the LM head at the output layer as the primary source of error. To address this, we increased the precision of the LM output head to FP32, thereby realigning the two theoretically identical probabilities.

I think I found the issue: on the main branch, the data passed to the activation function is in float32, but in this PR it is in the model's dtype. Let's see if using float32 fixes the problem.
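As a toy illustration of this effect (made-up values, not the vLLM code path), applying the activation in bfloat16 instead of float32 introduces a small rounding error per element:

```python
import torch

# Same logits; activation applied in bfloat16 vs. float32.
logits = torch.randn(8, dtype=torch.float32)
act_bf16 = torch.sigmoid(logits.to(torch.bfloat16)).to(torch.float32)
act_fp32 = torch.sigmoid(logits)

# Tiny per-element error, but enough to shift MTEB-style scores.
print((act_bf16 - act_fp32).abs().max())
```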

  • reward models

# TODO: The original reward weights are stored in float32; we would
# like to load them in fp32 to get that extra precision. Currently the
# weight_loader passes in a weight that has already been cast to bf16.
self.score = nn.Linear(
    config.hidden_size,
    num_labels,
    bias=score_bias,
    dtype=torch.float32,
)
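
A minimal sketch of the intended behavior (illustrative only; the class name is made up and the actual vLLM wiring differs): keep the head weights in fp32 and upcast the hidden states before the matmul, so the entire head computation runs in fp32 even when the backbone is bf16/fp16.

```python
import torch
import torch.nn as nn

class FP32ScoreHead(nn.Module):
    """Toy classification head that computes its output in fp32."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        # Head weights stay in fp32 even if the backbone is bf16/fp16.
        self.score = nn.Linear(hidden_size, num_labels, dtype=torch.float32)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Upcast bf16/fp16 activations so the matmul runs in fp32.
        return self.score(hidden_states.float())
```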

Let's add systematic support for an fp32 head (a configuration sketch follows the list below):

  1. Add VLLM_USING_FP32_HEAD: 1 to enable, 0 to disable, "" for the default behavior.
  2. Add head_type to hf_config so that users can override it via --hf-overrides.
  3. Pooling models default to using an fp32 head, which does not significantly increase computation.
  4. Generation models default to not using an fp32 head, because lm_head has shape [hidden_size, vocab_size], so an fp32 copy would be large; set VLLM_USING_FP32_HEAD to enable it.
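
The test commands below use VLLM_HEAD_DTYPE with the values "float32" and "model". Here is a minimal sketch of how such a setting could be resolved; the function name, the is_pooling_model flag, and the fallback semantics are assumptions for illustration, not vLLM's actual code:

```python
import os
import torch

def resolve_head_dtype(model_dtype: torch.dtype,
                       is_pooling_model: bool) -> torch.dtype:
    value = os.environ.get("VLLM_HEAD_DTYPE", "")
    if value == "float32":
        return torch.float32
    if value == "model":
        return model_dtype
    # Default ("" or unset): pooling models use fp32; generation models
    # keep the model dtype, since an fp32 lm_head copy would be large.
    return torch.float32 if is_pooling_model else model_dtype
```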

The generation support will be implemented in the next PR.

cc @DarkLight1337 @maxdebayser

Test Plan

Keep CI green.

MTEB test

Test Result

MTEB test; higher scores are better. Difference = st_main_score - vllm_main_score.

A negative diff indicates that vLLM, at its default dtype, scores better on MTEB than SentenceTransformers running in torch.float32. (Not entirely impossible, right?)

VLLM_HEAD_DTYPE="float32" pytest -s -vvv tests/models/language/pooling/test_st_projector.py
Model: TencentBAC/Conan-embedding-v1
VLLM: torch.float16 0.6886466841604848
SentenceTransformers: Constant 0.688611955
Difference: -3.472916048474772e-05 <- better 


VLLM_HEAD_DTYPE="model" pytest -s -vvv tests/models/language/pooling/test_st_projector.py
Model: TencentBAC/Conan-embedding-v1
VLLM: torch.float16 0.688634711512463
SentenceTransformers: Constant 0.688611955
Difference: -2.275651246297361e-05

VLLM_HEAD_DTYPE="float32" pytest -s -vvv tests/models/language/pooling/test_qwen3_reranker.py
Model: Qwen/Qwen3-Reranker-0.6B
VLLM: torch.bfloat16 0.25828
SentenceTransformers: torch.float32 0.25782
Difference: -0.00046000000000001595 <- better 


VLLM_HEAD_DTYPE="model" pytest -s -vvv tests/models/language/pooling/test_qwen3_reranker.py
Model: Qwen/Qwen3-Reranker-0.6B
VLLM: torch.bfloat16 0.25685
SentenceTransformers: torch.float32 0.25782
Difference: 0.0009699999999999709

VLLM_HEAD_DTYPE="float32" pytest -s -vvv tests/models/language/pooling/test_cross_encoder.py
Model: cross-encoder/ms-marco-TinyBERT-L-2-v2
VLLM: torch.float16 0.32884
SentenceTransformers: torch.float32 0.3288
Difference: -4.0000000000040004e-05 <- same

Model: tomaarsen/Qwen3-Reranker-0.6B-seq-cls
VLLM: torch.float16 0.25783
SentenceTransformers: torch.float32 0.25782
Difference: -1.0000000000010001e-05 <- better 


VLLM_HEAD_DTYPE="model" pytest -s -vvv tests/models/language/pooling/test_cross_encoder.py
Model: cross-encoder/ms-marco-TinyBERT-L-2-v2
VLLM: torch.float16 0.32884
SentenceTransformers: torch.float32 0.3288
Difference: -4.0000000000040004e-05

Model: tomaarsen/Qwen3-Reranker-0.6B-seq-cls
VLLM: torch.float16 0.25782
SentenceTransformers: torch.float32 0.25782
Difference: 0.0


This MTEB test can distinguish whether the fp32 head is used; give me an emoji if you think it's cool.

↓↓↓↓↓↓↓↓↓↓

@mergify mergify bot added the qwen (Related to Qwen models) label Aug 28, 2025

mergify bot commented Aug 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 29, 2025

@maxdebayser maxdebayser left a comment


Definitely an improvement, thanks. I've left a suggestion regarding the ModelConfig changes.

@mergify mergify bot removed the needs-rebase label Sep 3, 2025

noooop commented Sep 3, 2025

@DarkLight1337 @maxdebayser

Ready for review


mergify bot commented Sep 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 9, 2025
@mergify mergify bot removed the needs-rebase label Sep 9, 2025

noooop commented Sep 9, 2025

@DarkLight1337

Please take a look at this thread

@DarkLight1337 DarkLight1337 added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 9, 2025
@vllm-bot vllm-bot merged commit 19332c0 into vllm-project:main Sep 9, 2025
43 of 46 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
@noooop noooop deleted the pooling_fp32_head branch September 10, 2025 06:10
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
@noooop noooop restored the pooling_fp32_head branch September 13, 2025 07:46
rogeryoungh pushed a commit to MiniMax-AI/vllm that referenced this pull request Sep 15, 2025
cboss6 pushed a commit to cboss6/vllm that referenced this pull request Sep 16, 2025