
Add optional RMSNorm support to BitNet quantization (config + layers) #38087


Merged: 12 commits merged into huggingface:main on May 16, 2025

Conversation

@Codys12 (Contributor) commented May 12, 2025

What does this PR do?

Adds optional RMSNorm support to BitNet-style quantization.

  • Introduces use_rms_norm (bool, default False) and rms_norm_eps (float, default 1e-6) to BitNetQuantConfig so the flags are serializable via save_pretrained / from_pretrained.
  • Updates BitLinear and AutoBitLinear to accept use_rms_norm and apply the reference BitNetRMSNorm to activations before quantization (see the usage sketch below).
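
For illustration, a minimal sketch of how the flag could be used once merged. The checkpoint name is only an example taken from the discussion below, and the exact import path is assumed:

```python
from transformers import AutoModelForCausalLM, BitNetQuantConfig

# Enable the optional pre-quantization RMSNorm added in this PR.
quant_config = BitNetQuantConfig(
    use_rms_norm=True,   # defaults to False, which keeps the previous behaviour
    rms_norm_eps=1e-6,   # epsilon inside the RMS normalization
)

model = AutoModelForCausalLM.from_pretrained(
    "codys12/bitnet-r1-8b",  # example checkpoint mentioned later in the thread
    quantization_config=quant_config,
)
```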

Before submitting

  • I read the contributor guidelines, PR section.
  • I’ve added the new config fields to to_dict, docstrings, and the model card.
  • New unit tests
  • Ran make style && make quality && make test locally.
  • Documentation build passes (make docs) – pushed logs to CI.

Motivation and context

RMSNorm stabilizes the activations of low-bit networks; the BitNet paper reports a consistent perplexity drop when pre-quantization activations are normalized. This PR brings parity with the reference implementation while keeping the previous behaviour as the default.
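
For context, RMSNorm rescales each activation vector by its root-mean-square. A minimal PyTorch sketch in the spirit of the reference BitNetRMSNorm (illustrative, not the exact class touched by this PR):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # learnable gain
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x / sqrt(mean(x^2) + eps), computed in float32 for numerical stability
        input_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
```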

No new external dependencies.

Who can review?

Quantization / Accelerate folks for the code:
@SunMarc @MekkCyber

Docstrings & config: @stevhliu

Feel free to jump in with any feedback!

@github-actions (bot) marked this pull request as draft May 12, 2025 15:15
github-actions (bot) commented:
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@SunMarc (Member) left a comment


Can you share a bit about the motivation for such a feature?

@Codys12 (Contributor, Author) commented May 12, 2025

@SunMarc
Sure! I recently found that initializing an extra RMSNorm in each Linear layer lets you fine-tune existing LLMs to the BitNet format; I tested this for https://huggingface.co/codys12/bitnet-r1-32b and https://huggingface.co/codys12/bitnet-r1-8b. The process should work for any model, and this PR is a minimal implementation for easy conversion.

@MekkCyber (Contributor) commented:

Hi @Codys12, thanks for the PR 🤗! I'm not sure I understand the idea behind adding an extra RMSNorm: in the current BitNet implementation the authors already include one in the modeling code. For example, the MLP looks like:

```python
class BitNetMLP(GemmaMLP):
    def __init__(self, config: BitNetConfig):
        super().__init__(config)
        self.ffn_sub_norm = BitNetRMSNorm(config.intermediate_size, eps=config.rms_norm_eps)

    def forward(self, x):
        down_proj = self.down_proj(self.ffn_sub_norm(self.act_fn(self.gate_proj(x)) * self.up_proj(x)))
        return down_proj
```

@Codys12 (Contributor, Author) commented May 13, 2025

> Hi @Codys12, thanks for the PR 🤗! I'm not sure I understand the idea behind adding an extra RMSNorm [...]

@MekkCyber @SunMarc Good point!

In sections 2 and 2.1 of the original BitNet paper (https://arxiv.org/pdf/2310.11453), the authors describe the reason for including the RMSNorm: it improves model performance at negligible compute/parameter cost. They include it in their modeling file, but others (see here) have tested it with alternative architectures (Llama, Mistral, DeepSeek V3, etc.) and observed an improvement.

This change to the quantization config is a model-agnostic way to introduce the parameter, so a new modeling_*.py file is not required for every model you want to test this way.

Additionally, including this norm lets you fine-tune existing models to this quantization (see here), as demonstrated by https://huggingface.co/codys12/bitnet-r1-32b and https://huggingface.co/codys12/bitnet-r1-8b. A rough sketch of the idea is below.
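
To make the idea concrete, here is a self-contained sketch of a BitLinear-style layer with the optional pre-quantization norm. The class and helper names are hypothetical, and the quantizers follow the BitNet b1.58 reference recipe rather than this PR's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    # Per-token absmax quantization of activations to 8 bits (BitNet b1.58 recipe).
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Mean-absmax scaling of weights to ternary values {-1, 0, 1}.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale

class BitLinearWithNorm(nn.Linear):
    # Hypothetical layer: nn.Linear with an optional RMSNorm applied
    # to the activations right before they are quantized.
    def __init__(self, in_features, out_features, use_rms_norm=False, rms_norm_eps=1e-6):
        super().__init__(in_features, out_features, bias=False)
        self.use_rms_norm = use_rms_norm
        self.rms_norm_eps = rms_norm_eps
        if use_rms_norm:
            self.norm_weight = nn.Parameter(torch.ones(in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_rms_norm:
            # Normalize each activation vector by its RMS before quantization.
            variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.rms_norm_eps) * self.norm_weight
        return F.linear(activation_quant(x), weight_quant(self.weight))
```

The extra per-layer norm is initialized to an identity-like gain of ones, which is what makes fine-tuning an existing full-precision checkpoint into the BitNet format tractable.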

Let me know if there is anything that is still unclear!

@MekkCyber (Contributor) left a comment


Thanks for the explanation @Codys12, I see the idea behind this!

@MekkCyber MekkCyber marked this pull request as ready for review May 13, 2025 13:59
@SunMarc (Member) left a comment


Thanks! Just a couple of nits.

@Codys12 (Contributor, Author) commented May 13, 2025

> Thanks! Just a couple of nits.

@SunMarc Just made these changes; let me know if there is anything else I can do before the merge!

@MekkCyber (Contributor) commented:

Thanks @Codys12, can you please run make fix-copies and make style to fix the CI 🤗

steinmetzc and others added 2 commits May 13, 2025 13:03
@Codys12 (Contributor, Author) commented May 13, 2025

@MekkCyber
make fix-copies is passing, but the CI tests are still failing... Should I add it to OBJECTS_TO_IGNORE in utils/check_docstrings.py?

@Codys12 (Contributor, Author) commented May 14, 2025

@MekkCyber @SunMarc Any ideas on CI here? Looking to help this move forward today

@SunMarc (Member) left a comment


Thanks!

@SunMarc (Member) commented May 14, 2025

Try running make fix-copies, but remove the default values for the arguments you added in BitNetQuantConfig. I think the issue is there.

```
Traceback (most recent call last):
  File "/root/transformers/utils/check_docstrings.py", line 1467, in <module>
    check_docstrings(overwrite=args.fix_and_overwrite, check_all=args.check_all)
  File "/root/transformers/utils/check_docstrings.py", line 1456, in check_docstrings
    raise ValueError(error_message)
ValueError: There was at least one problem when checking docstrings of public objects.
The following objects docstrings do not match their signature. Run `make fix-copies` to fix this. In some cases, this error may be raised incorrectly by the docstring checker. If you think this is the case, you can manually check the docstrings and then add the object name to `OBJECTS_TO_IGNORE` in `utils/check_docstrings.py`.
- BitNetQuantConfig
```
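
For reference, check_docstrings compares each documented argument against the __init__ signature, including the *optional* marker and the default value. A hedged sketch of the shape it expects (wording illustrative, not the PR's final docstring):

```python
class BitNetQuantConfig:
    """
    Args:
        use_rms_norm (`bool`, *optional*, defaults to `False`):
            Whether to apply RMSNorm to activations before quantization.
        rms_norm_eps (`float`, *optional*, defaults to `1e-6`):
            The epsilon used by the RMSNorm for numerical stability.
    """

    def __init__(self, use_rms_norm: bool = False, rms_norm_eps: float = 1e-6):
        self.use_rms_norm = use_rms_norm
        self.rms_norm_eps = rms_norm_eps
```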

@Codys12 (Contributor, Author) commented May 14, 2025

> Try running make fix-copies, but remove the default values for the arguments you added in BitNetQuantConfig. I think the issue is there. [...]

@SunMarc Hmm, I changed it to optional, but running make fix-copies doesn't change anything. It might be because the CI loop runs on another device that I pull my fork into... Do you have any experience with getting the formatting fixed?

@Codys12 (Contributor, Author) commented May 14, 2025

Wait, all tests are passing. Sick!

@HuggingFaceDocBuilderDev (bot) commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@MekkCyber merged commit 1e921a3 into huggingface:main May 16, 2025
20 checks passed
@MekkCyber (Contributor) commented:

Thanks for the PR 🤗!
