
Conversation

noooop
Contributor

@noooop noooop commented Sep 10, 2025

Purpose

Follow #23810

"head" refers to the last Linear layer(s) of an LLM, such as the lm_head in a generation model, or the score or classifier in a classification model.

An increasing amount of evidence suggests that using an fp32 head can improve numerical precision.

From https://www.arxiv.org/pdf/2506.13585#page=7.62:

> Through layer-by-layer analysis, we identified high-magnitude activations in the LM head at the output layer as the primary source of error. To address this, we increased the precision of the LM output head to FP32, thereby realigning the two theoretically identical probabilities.
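As a minimal, self-contained illustration of the effect (toy tensors, not vLLM code), the same hidden state projected through the same head in bf16 vs. fp32 already yields measurably different logprobs:

```python
import torch

torch.manual_seed(0)

# Toy example: project one hidden state onto a vocabulary in bf16 vs. fp32.
# Tensor names and sizes are made up for this sketch.
hidden = torch.randn(1, 4096, dtype=torch.float32)
lm_head = torch.randn(32000, 4096, dtype=torch.float32) * 0.02

logits_fp32 = hidden @ lm_head.t()
logits_bf16 = (hidden.bfloat16() @ lm_head.bfloat16().t()).float()

logprobs_fp32 = torch.log_softmax(logits_fp32, dim=-1)
logprobs_bf16 = torch.log_softmax(logits_bf16, dim=-1)

# This gap is the kind of rollout/training mismatch discussed in this PR.
print((logprobs_fp32 - logprobs_bf16).abs().max())
```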

For the pooling-models part, PTAL at #23810.

Fix #19925

Needs to be modified:

  1. tie_word_embeddings <- currently the biggest problem
  2. Support for quantized models
  3. hidden_states = hidden_states.to(self.head_dtype) (see the sketch after this list)
  4. logits_processor and the samplers need to support head_dtype (may already be supported)
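A minimal sketch of what item 3 could look like, assuming a standalone head module that owns a head_dtype attribute; the actual vLLM wiring (quantized layers, tie_word_embeddings, logits processors) is more involved:

```python
import torch
import torch.nn as nn

class ScoreHead(nn.Module):
    """Illustrative head that runs its projection in a configurable dtype.

    `head_dtype` is an assumption for this sketch, not the real vLLM API.
    """

    def __init__(self, hidden_size: int, num_labels: int,
                 head_dtype: torch.dtype = torch.float32):
        super().__init__()
        self.head_dtype = head_dtype
        self.proj = nn.Linear(hidden_size, num_labels, dtype=head_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Item 3 above: cast activations to the head dtype before the
        # final projection.
        hidden_states = hidden_states.to(self.head_dtype)
        return self.proj(hidden_states)
```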

Test Plan

VLLM_CI_HF_DTYPE="float32" VLLM_CI_HEAD_DTYPE="model" pytest -s -vvv tests/models/language/generation_ppl_test/test_qwen.py::test_ppl[model_info0]

Model: Qwen/Qwen3-0.6B
VLLM: dtype:torch.bfloat16 head_dtype:torch.bfloat16 23.85902976989746
Transformers: torch.float32 23.854570388793945
Difference (%): 0.01869403234195527
PASSED



VLLM_CI_HF_DTYPE="float32" VLLM_CI_HEAD_DTYPE="float32" pytest -s -vvv tests/models/language/generation_ppl_test/test_qwen.py::test_ppl[model_info0]

Model: Qwen/Qwen3-0.6B
VLLM: dtype:torch.bfloat16 head_dtype:torch.float32 23.852705001831055
Transformers: torch.float32 23.854570388793945
Difference (%): -0.007819830466395318
PASSED       


VLLM_CI_HF_DTYPE="float32" VLLM_CI_DTYPE="float32" pytest -s -vvv tests/models/language/generation_ppl_test/test_qwen.py::test_ppl[model_info0]

Model: Qwen/Qwen3-0.6B
VLLM: dtype:torch.float32 head_dtype:torch.float32 23.854564666748047
Transformers: torch.float32 23.854570388793945
Difference (%): -2.398721001961754e-05
PASSED     


Test Result

PPL test: PTAL at #24485.

Smaller is better. Negative values mean that the vLLM result is lower than the Transformers float32 baseline.

| Model | Transformers float32 (PPL) | vLLM default (diff %) | vLLM float32 head (diff %) | vLLM float32 (diff %) |
|---|---|---|---|---|
| openai-community/gpt2-large | 19.44617462 | 0.049% | 0.045% | 0.000% |
| google/gemma-2b | 21.51252747 | -0.100% | -0.120% | 0.033% |
| Qwen/Qwen3-0.6B | 23.85457039 | 0.019% | -0.008% | 0.000% |

Using a float32 head is indeed better for the PPL test, sometimes even better than using float32 for all parameters.
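For reference, a small sketch of how the difference percentages above appear to be computed (relative difference against the Transformers float32 baseline); the values are copied from the Qwen3-0.6B runs in the Test Plan:

```python
# Relative PPL difference against the Transformers float32 baseline.
def ppl_diff_pct(vllm_ppl: float, hf_ppl: float) -> float:
    return (vllm_ppl - hf_ppl) / hf_ppl * 100

print(ppl_diff_pct(23.85902976989746, 23.854570388793945))   # ~0.019%  (default head)
print(ppl_diff_pct(23.852705001831055, 23.854570388793945))  # ~-0.008% (float32 head)
```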

Discussion

The ultimate question: does this difference really matter in RLHF?

cc @22quinn @houseroad @yeqcharlotte @hijkzzz



@mergify mergify bot added the qwen Related to Qwen models label Sep 10, 2025
@noooop
Contributor Author

noooop commented Sep 10, 2025

cc @22quinn

I was wondering whether using an fp32 head is critical to the success of RLHF training.

How can we construct a test to verify this?


I'm not very familiar with RLHF.

I came up with a simple estimation method: according to the PPL test above, a float32 head (vs. a model-dtype head) shifts PPL by roughly 0.002~0.0002 (20 * 0.01%~0.001%).
May I ask: when RLHF converges, what is the approximate order of magnitude of the PPL, and of its variance?

(My proposed estimation method may be complete nonsense.)
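For concreteness, the back-of-the-envelope arithmetic behind that 0.002~0.0002 estimate (assuming a baseline PPL of roughly 20 and the relative shifts from the table above):

```python
# Absolute PPL shift implied by a given relative shift at a baseline PPL of ~20.
baseline_ppl = 20.0
for rel_shift in (0.01e-2, 0.001e-2):  # 0.01% and 0.001%
    print(baseline_ppl * rel_shift)    # 0.002 and 0.0002
```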

@noooop
Contributor Author

noooop commented Sep 10, 2025

cc: @houseroad @yeqcharlotte @hijkzzz

Welcome to the discussion


mergify bot commented Sep 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @noooop.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 10, 2025
@22quinn 22quinn added the rl Related to RL workflows label Sep 11, 2025
@22quinn
Collaborator

22quinn commented Sep 11, 2025

Both fp32 LM head and perplexity eval were in my backlog :) Thanks a lot for this!

I think there are a few aspects here:

  1. fp32 LM head -> better perplexity: Your test result validated this
  2. better perplexity -> better RL outcome? And your detailed questions on what perplexity value is good for RL?
  3. Direct implication: fp32 LM head -> better RL outcome?

For 2), I'm not qualified to answer this before we get more data internally. cc @zhuohan123 in case any prior experience
For 3), FlashRL people have an awesome note: https://fengyao.notion.site/off-policy-rl However, I don't know if they did an ablation study on LM head alone. cc @yaof20 @LiyuanLucasLiu

@noooop noooop force-pushed the generation_fp32_head branch from 6801a36 to e53b713 Compare September 12, 2025 03:14
@noooop
Contributor Author

noooop commented Sep 12, 2025

In off-policy RL, the LM head can run as W16A16 (model-dtype bfloat16 weights and activations), W16A32 (bfloat16 weights with a float32 output dtype), or W32A32 (a native float32 LM-head weight). Will this choice affect the RL outcome?
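To make the three modes concrete, a minimal sketch (tensor names and sizes are made up; this is not the vLLM implementation):

```python
import torch

hidden_bf16 = torch.randn(1, 4096, dtype=torch.bfloat16)
w_fp32 = torch.randn(32000, 4096) * 0.02   # pretend fp32 master weight
w_bf16 = w_fp32.bfloat16()                 # weight as shipped in bf16

# W16A16: bf16 weight, bf16 matmul (current default).
logits_w16a16 = hidden_bf16 @ w_bf16.t()

# W16A32: bf16 weight, but compute/output in fp32.
logits_w16a32 = hidden_bf16.float() @ w_bf16.float().t()

# W32A32: native fp32 weight and fp32 compute.
logits_w32a32 = hidden_bf16.float() @ w_fp32.t()
```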

@yaof20

yaof20 commented Sep 12, 2025

> Both fp32 LM head and perplexity eval were in my backlog :) Thanks a lot for this!
>
> I think there are a few aspects here:
>
> 1. fp32 LM head -> better perplexity: Your test result validated this
> 2. better perplexity -> better RL outcome? And your detailed questions on what perplexity value is good for RL?
> 3. Direct implication: fp32 LM head -> better RL outcome?
>
> For 2), I'm not qualified to answer this before we get more data internally. cc @zhuohan123 in case any prior experience
> For 3), FlashRL people have an awesome note: https://fengyao.notion.site/off-policy-rl However, I don't know if they did an ablation study on LM head alone. cc @yaof20 @LiyuanLucasLiu

Hi, thanks for asking! Based on our experimental results, we found that using an FP32 LM head is a good-to-have feature, but it does not necessarily bring a performance gain or close the rollout-training mismatch gap. Here are some key observations to share:

1. MiniMax-M1 Observations
The MiniMax-M1 technical report notes that using an FP32 LM head helps align training and rollout probabilities, thereby stabilizing RL training. It is worth noting that MiniMax-M1 is a linear-attention model, where activation values tend to be large before the LM head. In such cases, switching to FP32 can be sufficient to compensate for the numerical instability.

2. Our Findings on DAPO-32B
In our experiments with DAPO-32B in our blog, applying an FP32 LM head (the red line in Figure 1) did not close the rollout-training mismatch gap, nor did it improve downstream task performance. On the contrary, it increased memory usage, as also indicated by the vLLM post. Given the trade-offs, we decided not to proceed further with the FP32 LM head in this setup.

Therefore, while an FP32 LM head may help in some architectures (such as linear attention), its benefits appear to be context-dependent rather than universal. More broadly, we still need a better understanding and more data to conclude whether improved perplexity directly correlates with better RL outcomes.

@noooop
Contributor Author

noooop commented Sep 12, 2025

Thank you for sharing.

Regarding "2. Our Findings on DAPO-32B": I did not find "Figure 1" in the DAPO paper.

Could you please point me to where I can see it?

@yaof20

yaof20 commented Sep 12, 2025

> Thank you for sharing.
>
> Regarding "2. Our Findings on DAPO-32B": I did not find "Figure 1" in the DAPO paper.
>
> Could you please point me to where I can see it?

Hi, I guess you mean the Figure 1 I showed in my previous reply? If so, Figure 1 is from our blog here: https://fengyao.notion.site/off-policy-rl, where the red line in Figure 1 indicates the FP32 LM head run. :)

@noooop
Contributor Author

noooop commented Sep 12, 2025

> > Regarding "2. Our Findings on DAPO-32B": I did not find "Figure 1" in the DAPO paper. Could you please point me to where I can see it?
>
> Hi, I guess you mean the Figure 1 I showed in my previous reply? If so, Figure 1 is from our blog here: https://fengyao.notion.site/off-policy-rl, where the red line in Figure 1 indicates the FP32 LM head run. :)

Thanks

@noooop
Contributor Author

noooop commented Sep 12, 2025

@yaof20

https://github.com/yaof20/Flash-RL/blob/07f3c21628125bcf609d18f6e854171ec17b984c/flash_rl/vllm_patch.py#L791-L804

Stupid questions:

  • Are the model_to_be_reloaded weights bf16 or fp32 here?
  • Does VERL use mixed precision, and does it save an fp32 master copy?
  • Did the experiment use W16A32 mode or W32A32 mode (loading native fp32 weights directly from the fp32 master copy)? See the sketch below.
  • Could using W32A32 mode make it slightly better?

(By the way, Qwen 32B does not use tie_word_embeddings, which makes many things simpler.)
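To illustrate the W16A32 vs. W32A32 distinction behind the third question, here is a hypothetical snippet; it is not FlashRL's or VERL's actual weight-sync code, and the tensor names are made up:

```python
import torch

# Pretend fp32 master copy of the LM head, as a mixed-precision trainer might keep it.
master_fp32 = {"lm_head.weight": torch.randn(32000, 4096)}

# W16A32: the head weight is round-tripped through bf16 before reload,
# so only the activations/outputs are fp32.
reload_w16a32 = {k: v.bfloat16().float() for k, v in master_fp32.items()}

# W32A32: the fp32 master copy is loaded directly into an fp32 head.
reload_w32a32 = {k: v.clone() for k, v in master_fp32.items()}

# The difference below is the information lost by casting the weight to bf16.
print((reload_w16a32["lm_head.weight"] - reload_w32a32["lm_head.weight"]).abs().max())
```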

@noooop noooop force-pushed the generation_fp32_head branch from e53b713 to c694835 Compare September 12, 2025 09:03
@mergify mergify bot removed the needs-rebase label Sep 12, 2025