
Commit 78f3e08

vasqu authored and ArthurZucker committed
[RMSNorm] Fix rms norm init for models that center around 1 (#40796)
* fix
* fixup inits
* oops
* fixup gemma
* fixup modular order
* how does this keep happen lol
* vaultgemma is new i forgot
* remove init check
1 parent a5ffae6 commit 78f3e08
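
For context, the Gemma-family RMSNorm applies its learnable scale as (1 + weight) rather than weight alone, so the weight must start at 0 for the layer to act as a plain RMS normalization at initialization. A rough sketch of the idea (not the exact GemmaRMSNorm implementation; the class name below is made up):

import torch
from torch import nn


class OneCenteredRMSNorm(nn.Module):
    # Sketch of a Gemma-style RMSNorm: the learnable scale enters as (1 + weight),
    # so weight == 0 means "scale by exactly 1" right after initialization.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        normed = x.float() * torch.rsqrt(variance + self.eps)
        return (normed * (1.0 + self.weight.float())).type_as(x)

A standard Llama-style RMSNorm scales by weight directly and is initialized with ones; reusing that init for the 1-centered variants is what this commit corrects by overriding _init_weights per model.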

12 files changed (+103, -37 lines)


src/transformers/models/gemma/modeling_gemma.py

Lines changed: 7 additions & 0 deletions

@@ -322,6 +322,13 @@ class GemmaPreTrainedModel(PreTrainedModel):
         "attentions": GemmaAttention,
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        if "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
+
 
 @auto_docstring
 class GemmaModel(GemmaPreTrainedModel):
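
One way to sanity-check the new override (a hypothetical snippet assuming a transformers build that includes this commit; the tiny config values are made up purely for illustration):

import torch
from transformers import GemmaConfig, GemmaModel

# Deliberately tiny, made-up config so the model builds quickly
config = GemmaConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=1,
    head_dim=8,
)
model = GemmaModel(config)

# After _init_weights, every RMSNorm weight should be 0, i.e. an effective scale of (1 + 0) = 1
assert torch.all(model.norm.weight == 0)
assert torch.all(model.layers[0].input_layernorm.weight == 0)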

src/transformers/models/gemma/modular_gemma.py

Lines changed: 16 additions & 0 deletions

@@ -23,6 +23,7 @@
 from ...configuration_utils import PretrainedConfig
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
 from ...processing_utils import Unpack
 from ...tokenization_utils import AddedToken, PreTrainedTokenizer
 from ...utils import TransformersKwargs, logging
@@ -32,6 +33,8 @@
     LlamaForTokenClassification,
     LlamaMLP,
     LlamaModel,
+    LlamaPreTrainedModel,
+    LlamaRotaryEmbedding,
 )
 from ..llama.tokenization_llama import LlamaTokenizer
 
@@ -366,6 +369,19 @@ def __init__(self, config):
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
 
+class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
+    pass
+
+
+class GemmaPreTrainedModel(LlamaPreTrainedModel):
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        if "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
+
+
 class GemmaModel(LlamaModel):
     def forward(
         self,

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 43 additions & 36 deletions

@@ -83,6 +83,42 @@ def forward(self, x):
         return down_proj
 
 
+class Gemma2RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Gemma2Config, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
@@ -299,42 +335,6 @@ def forward(
         return outputs
 
 
-class Gemma2RotaryEmbedding(nn.Module):
-    inv_freq: torch.Tensor  # fix linting for `register_buffer`
-
-    def __init__(self, config: Gemma2Config, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 @auto_docstring
 class Gemma2PreTrainedModel(PreTrainedModel):
     config: Gemma2Config
@@ -353,6 +353,13 @@ class Gemma2PreTrainedModel(PreTrainedModel):
         "attentions": Gemma2Attention,
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        if "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
+
 
 @auto_docstring
 class Gemma2Model(Gemma2PreTrainedModel):

src/transformers/models/gemma2/modular_gemma2.py

Lines changed: 10 additions & 0 deletions

@@ -36,7 +36,9 @@
     GemmaForTokenClassification,
     GemmaMLP,
     GemmaModel,
+    GemmaPreTrainedModel,
     GemmaRMSNorm,
+    GemmaRotaryEmbedding,
     apply_rotary_pos_emb,
     repeat_kv,
 )
@@ -212,6 +214,10 @@ def __init__(self, config):
         self.act_fn = ACT2FN[config.hidden_activation]
 
 
+class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
+    pass
+
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
@@ -363,6 +369,10 @@ def forward(
         return outputs
 
 
+class Gemma2PreTrainedModel(GemmaPreTrainedModel):
+    pass
+
+
 class Gemma2Model(GemmaModel):
     def __init__(self, config: Gemma2Config):
         super().__init__(config)

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 3 additions & 0 deletions

@@ -434,6 +434,9 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, Gemma3MultiModalProjector):
             module.mm_input_projection_weight.data.zero_()
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        elif "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
 
 
 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 3 additions & 0 deletions

@@ -526,6 +526,9 @@ def _init_weights(self, module):
         PreTrainedModel._init_weights(self, module)
         if isinstance(module, Gemma3MultiModalProjector):
             module.mm_input_projection_weight.data.zero_()
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        elif "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
 
 
 def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:

src/transformers/models/qwen3_next/modeling_qwen3_next.py

Lines changed: 3 additions & 0 deletions

@@ -970,6 +970,9 @@ def _init_weights(self, module):
         if isinstance(module, Qwen3NextGatedDeltaNet):
             module.dt_bias.data.fill_(1.0)
             module.A_log.data.uniform_(0, 16).log_()
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        elif isinstance(module, Qwen3NextRMSNorm):
+            module.weight.data.zero_()
 
 
 class Qwen3NextModel(Qwen3NextPreTrainedModel):

src/transformers/models/qwen3_next/modular_qwen3_next.py

Lines changed: 3 additions & 0 deletions

@@ -709,6 +709,9 @@ def _init_weights(self, module):
         if isinstance(module, Qwen3NextGatedDeltaNet):
             module.dt_bias.data.fill_(1.0)
             module.A_log.data.uniform_(0, 16).log_()
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        elif isinstance(module, Qwen3NextRMSNorm):
+            module.weight.data.zero_()
 
 
 class Qwen3NextModel(Qwen3NextPreTrainedModel):

src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

Lines changed: 2 additions & 1 deletion

@@ -556,8 +556,9 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif isinstance(module, RecurrentGemmaRMSNorm):
-            module.weight.data.fill_(1.0)
+            module.weight.data.zero_()
 
     def _setup_cache(self, config, batch, device, dtype):
         layers = getattr(self, "model", self).layers
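
Since the comment above notes that RecurrentGemmaRMSNorm is 1-centered, the previous fill_(1.0) left the norm multiplying by 1 + 1 = 2 at initialization; zero_() restores the intended 1 + 0 = 1, matching the overrides added for the other models in this commit.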

src/transformers/models/t5gemma/modeling_t5gemma.py

Lines changed: 3 additions & 0 deletions

@@ -611,6 +611,9 @@ def _init_weights(self, module):
             if not self.config.tie_word_embeddings:
                 scale = module.out_proj.weight.shape[0] ** -0.5
                 module.out_proj.weight.data.normal_(mean=0.0, std=std * scale)
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        elif "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
 
     def _shift_right(self, input_ids):
         """
