Add optional RMSNorm support to BitNet quantization (config + layers) #38087

Merged
merged 12 commits on May 16, 2025
4 changes: 3 additions & 1 deletion src/transformers/convert_slow_tokenizer.py
@@ -1584,7 +1584,9 @@ def __init__(
self.pattern = pattern
self.add_prefix_space = add_prefix_space
self.additional_special_tokens = (
additional_special_tokens.keys() if type(additional_special_tokens) is dict else additional_special_tokens
additional_special_tokens.keys()
if isinstance(additional_special_tokens, dict)
else additional_special_tokens
)

def extract_vocab_merges_from_model(self, tiktoken_url: str):
40 changes: 38 additions & 2 deletions src/transformers/integrations/bitnet.py
@@ -124,7 +124,16 @@ def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:


class BitLinear(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
device=None,
dtype=None,
use_rms_norm: bool = False,
rms_norm_eps: float = 1e-6,
):
super().__init__()
self.dtype = dtype
self.in_features = in_features
@@ -150,6 +159,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool, device=None,
else:
self.bias = None

# Optional RMSNorm (applied on the activations before quantization).
self.rms_norm = None
if use_rms_norm:
from ..models.llama.modeling_llama import LlamaRMSNorm

self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)

@torch.compile
def activation_quant(self, input, num_bits=8):
"""
@@ -180,6 +196,10 @@ def post_quant_process(self, input, input_scale, weight_scale):
return out

def forward(self, input):
# Apply RMSNorm on the input if requested.
if self.rms_norm is not None:
input = self.rms_norm(input)

w = self.weight
w_quant = unpack_weights(w, dtype=self.dtype)
input_quant, input_scale = self.activation_quant(input)
@@ -245,9 +265,17 @@ def __init__(
device=None,
dtype=None,
online_quant: bool = False,
use_rms_norm: bool = False,
rms_norm_eps: float = 1e-6,
):
super().__init__(in_features, out_features, bias)
self.online_quant = online_quant
# Optional RMSNorm
self.rms_norm = None
if use_rms_norm:
from ..models.llama.modeling_llama import LlamaRMSNorm

self.rms_norm = LlamaRMSNorm(in_features, eps=rms_norm_eps)
if not online_quant:
self.register_buffer(
"weight_scale",
@@ -271,6 +299,10 @@ def load_hook(
return state_dict

def forward(self, input):
# Optional RMSNorm on activations prior to quantization.
if self.rms_norm is not None:
input = self.rms_norm(input)

if self.online_quant:
weight = WeightQuant.apply(self.weight)
else:
@@ -318,6 +350,8 @@ def _replace_with_bitnet_linear(
device=module.weight.device,
dtype=module.weight.dtype,
online_quant=(quantization_config.quantization_mode == "online"),
use_rms_norm=quantization_config.use_rms_norm,
rms_norm_eps=quantization_config.rms_norm_eps,
)
if quantization_config.quantization_mode == "offline":
model._modules[name].requires_grad_(False)
@@ -328,6 +362,8 @@
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
use_rms_norm=quantization_config.use_rms_norm,
rms_norm_eps=quantization_config.rms_norm_eps,
)
model._modules[name].requires_grad_(False)
has_been_replaced = True
@@ -363,7 +399,7 @@ def replace_with_bitnet_linear(
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
Names of the modules to not convert in `BitLinear`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
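For context, the effect of `use_rms_norm=True` in the hunks above is an ordering guarantee: activations pass through `LlamaRMSNorm` first and only then through the existing per-token 8-bit absmax activation quantization. Below is a minimal standalone sketch of that ordering; the quantization helper is an illustrative re-implementation, not the library's exact `activation_quant`.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm


def absmax_quant_int8(x: torch.Tensor, num_bits: int = 8):
    """Illustrative per-token absmax quantization, mirroring the idea of
    BitLinear.activation_quant (not the library's exact implementation)."""
    qmax = 2 ** (num_bits - 1) - 1
    scale = qmax / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-qmax - 1, qmax), scale


hidden_size = 16
x = torch.randn(2, 4, hidden_size)

# With use_rms_norm=True, the activations are normalized first ...
rms_norm = LlamaRMSNorm(hidden_size, eps=1e-6)
x_normed = rms_norm(x)

# ... and only then quantized, as the patched forward() does.
x_quant, x_scale = absmax_quant_int8(x_normed)
print(x_quant.shape, x_scale.shape)  # torch.Size([2, 4, 16]) torch.Size([2, 4, 1])
```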
9 changes: 9 additions & 0 deletions src/transformers/utils/quantization_config.py
@@ -1791,6 +1791,11 @@ class BitNetQuantConfig(QuantizationConfigMixin):
In `offline` mode, quantization parameters are pre-calculated *before* inference.
These parameters are then fixed and loaded into the quantized model. This
generally results in lower runtime overhead compared to online quantization.
use_rms_norm (`bool`, *optional*, defaults to `False`):
Whether to apply RMSNorm on the activations before quantization. This matches the original BitNet paper's approach
of normalizing activations before quantization/packing.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon value used in the RMSNorm layer for numerical stability.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments that may be used by specific quantization
backends or future versions.
@@ -1801,6 +1806,8 @@ def __init__(
modules_to_not_convert: Optional[List] = None,
linear_class: Optional[str] = "bitlinear",
quantization_mode: Optional[str] = "offline",
use_rms_norm: Optional[bool] = False,
rms_norm_eps: Optional[float] = 1e-6,
**kwargs,
):
if linear_class not in ["bitlinear", "autobitlinear"]:
@@ -1811,6 +1818,8 @@
self.modules_to_not_convert = modules_to_not_convert
self.linear_class = linear_class
self.quantization_mode = quantization_mode
self.use_rms_norm = use_rms_norm
self.rms_norm_eps = rms_norm_eps
self.post_init()

def post_init(self):
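Finally, a minimal sketch of how the new options are surfaced to users, assuming `BitNetQuantConfig` is passed through `from_pretrained` the same way other quantization configs are; the checkpoint id is a placeholder.

```python
from transformers import AutoModelForCausalLM
from transformers.utils.quantization_config import BitNetQuantConfig

# Opt into the optional activation RMSNorm added by this PR.
quantization_config = BitNetQuantConfig(
    linear_class="autobitlinear",
    quantization_mode="online",
    use_rms_norm=True,
    rms_norm_eps=1e-6,
)

# Placeholder checkpoint id; whether a given checkpoint can be loaded this
# way depends on the BitNet quantizer's own requirements.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/bitnet-checkpoint",
    quantization_config=quantization_config,
)
```

From there the config values flow through `replace_with_bitnet_linear` into the `use_rms_norm` and `rms_norm_eps` arguments of `BitLinear` and `AutoBitLinear`, as shown in the `integrations/bitnet.py` hunks above.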