src/transformers/models/gemma/modeling_gemma.py (7 additions, 0 deletions)
@@ -322,6 +322,13 @@ class GemmaPreTrainedModel(PreTrainedModel):
"attentions": GemmaAttention,
}

def _init_weights(self, module):
super()._init_weights(module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


@auto_docstring
class GemmaModel(GemmaPreTrainedModel):
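For context on the comment added above: Gemma-style RMSNorm applies its learned scale as (1 + weight) rather than weight, so a zero-initialized weight leaves the module as a plain RMS rescale at init, while filling it with 1.0 (the usual LayerNorm convention) would scale activations by 2. A minimal sketch of that parameterization, not the library class; the dimension and eps are illustrative:

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Gemma-style RMSNorm: the learned scale enters as (1 + weight)."""

    def __init__(self, dim: int = 8, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero init => (1 + weight) == 1, i.e. no learned rescaling at init
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(variance + self.eps)
        return (x_normed * (1.0 + self.weight.float())).type_as(x)
```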
src/transformers/models/gemma/modular_gemma.py (16 additions, 0 deletions)
@@ -23,6 +23,7 @@
from ...configuration_utils import PretrainedConfig
from ...masking_utils import create_causal_mask
from ...modeling_outputs import BaseModelOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import TransformersKwargs, logging
@@ -32,6 +33,8 @@
LlamaForTokenClassification,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
)
from ..llama.tokenization_llama import LlamaTokenizer

@@ -361,6 +364,19 @@ def __init__(self, config):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)


class GemmaRotaryEmbedding(LlamaRotaryEmbedding):
pass


class GemmaPreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


class GemmaModel(LlamaModel):
def forward(
self,
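A quick way to observe the effect of the new GemmaPreTrainedModel override, as a hypothetical snippet: the config values are arbitrary, and `_init_weights` runs automatically via `post_init` when a model is built from a config rather than from pretrained weights.

```python
import torch
from transformers import GemmaConfig, GemmaModel

# Tiny randomly initialized model, no pretrained checkpoint involved
config = GemmaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=1,
)
model = GemmaModel(config)

norm = model.layers[0].input_layernorm   # a GemmaRMSNorm instance
print(torch.all(norm.weight == 0))       # expected: tensor(True), i.e. (1 + weight) == 1
```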
src/transformers/models/gemma2/modeling_gemma2.py (43 additions, 36 deletions)
@@ -83,6 +83,42 @@ def forward(self, x):
return down_proj


class Gemma2RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`

def __init__(self, config: Gemma2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -299,42 +335,6 @@ def forward(
return outputs


class Gemma2RotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`

def __init__(self, config: Gemma2Config, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


@auto_docstring
class Gemma2PreTrainedModel(PreTrainedModel):
config: Gemma2Config
@@ -353,6 +353,13 @@ class Gemma2PreTrainedModel(PreTrainedModel):
"attentions": Gemma2Attention,
}

def _init_weights(self, module):
super()._init_weights(module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


@auto_docstring
class Gemma2Model(Gemma2PreTrainedModel):
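The relocated Gemma2RotaryEmbedding above only reads the dtype and device from `x`; it returns per-position cos/sin tables of width `head_dim`. A rough usage sketch, with arbitrary config values:

```python
import torch
from transformers import Gemma2Config
from transformers.models.gemma2.modeling_gemma2 import Gemma2RotaryEmbedding

config = Gemma2Config(hidden_size=64, num_attention_heads=4, num_hidden_layers=2, head_dim=16)
rope = Gemma2RotaryEmbedding(config)

batch, seq_len = 2, 10
x = torch.randn(batch, seq_len, config.hidden_size)                 # only dtype/device are used
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)

cos, sin = rope(x, position_ids)
print(cos.shape, sin.shape)  # torch.Size([2, 10, 16]) each: (batch, seq_len, head_dim)
```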
src/transformers/models/gemma2/modular_gemma2.py (10 additions, 0 deletions)
@@ -36,7 +36,9 @@
GemmaForTokenClassification,
GemmaMLP,
GemmaModel,
GemmaPreTrainedModel,
GemmaRMSNorm,
GemmaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
@@ -212,6 +214,10 @@ def __init__(self, config):
self.act_fn = ACT2FN[config.hidden_activation]


class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
pass


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
@@ -363,6 +369,10 @@ def forward(
return outputs


class Gemma2PreTrainedModel(GemmaPreTrainedModel):
pass


class Gemma2Model(GemmaModel):
def __init__(self, config: Gemma2Config):
super().__init__(config)
src/transformers/models/gemma3/modeling_gemma3.py (3 additions, 0 deletions)
@@ -434,6 +434,9 @@ def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, Gemma3MultiModalProjector):
module.mm_input_projection_weight.data.zero_()
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
src/transformers/models/gemma3/modular_gemma3.py (3 additions, 0 deletions)
@@ -526,6 +526,9 @@ def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
if isinstance(module, Gemma3MultiModalProjector):
module.mm_input_projection_weight.data.zero_()
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
src/transformers/models/qwen3_next/modeling_qwen3_next.py (3 additions, 0 deletions)
@@ -970,6 +970,9 @@ def _init_weights(self, module):
if isinstance(module, Qwen3NextGatedDeltaNet):
module.dt_bias.data.fill_(1.0)
module.A_log.data.uniform_(0, 16).log_()
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif isinstance(module, Qwen3NextRMSNorm):
module.weight.data.zero_()


class Qwen3NextModel(Qwen3NextPreTrainedModel):
src/transformers/models/qwen3_next/modular_qwen3_next.py (3 additions, 0 deletions)
@@ -709,6 +709,9 @@ def _init_weights(self, module):
if isinstance(module, Qwen3NextGatedDeltaNet):
module.dt_bias.data.fill_(1.0)
module.A_log.data.uniform_(0, 16).log_()
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif isinstance(module, Qwen3NextRMSNorm):
module.weight.data.zero_()


class Qwen3NextModel(Qwen3NextPreTrainedModel):
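The Qwen3NextGatedDeltaNet branch above chains in-place tensor ops (`uniform_` returns the same tensor, so the trailing `.log_()` also modifies it in place). A small standalone illustration of what each call leaves behind; the tensor size is arbitrary:

```python
import torch

dt_bias = torch.empty(8)
dt_bias.fill_(1.0)            # all ones

A_log = torch.empty(8)
A_log.uniform_(0, 16).log_()  # sample A ~ U(0, 16), then store log(A) in place

print(dt_bias)                # tensor of ones
print(A_log)                  # values bounded above by log(16) ≈ 2.77
```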
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -556,8 +556,9 @@ def _init_weights(self, module):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
         elif isinstance(module, RecurrentGemmaRMSNorm):
-            module.weight.data.fill_(1.0)
+            module.weight.data.zero_()
 
     def _setup_cache(self, config, batch, device, dtype):
         layers = getattr(self, "model", self).layers
src/transformers/models/t5gemma/modeling_t5gemma.py (3 additions, 0 deletions)
@@ -611,6 +611,9 @@ def _init_weights(self, module):
if not self.config.tie_word_embeddings:
scale = module.out_proj.weight.shape[0] ** -0.5
module.out_proj.weight.data.normal_(mean=0.0, std=std * scale)
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()

def _shift_right(self, input_ids):
"""
src/transformers/models/t5gemma/modular_t5gemma.py (3 additions, 0 deletions)
@@ -491,6 +491,9 @@ def _init_weights(self, module):
if not self.config.tie_word_embeddings:
scale = module.out_proj.weight.shape[0] ** -0.5
module.out_proj.weight.data.normal_(mean=0.0, std=std * scale)
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
elif "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()

def _shift_right(self, input_ids):
"""
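In the untied-head branch of the two T5Gemma hunks above, the extra factor is just the inverse square root of the output width. A worked example with illustrative numbers (2048 stands in for `module.out_proj.weight.shape[0]`, 0.02 for the base `std`):

```python
out_features = 2048             # illustrative out_proj output width
std = 0.02                      # illustrative base std
scale = out_features ** -0.5    # about 0.0221

print(std * scale)              # about 0.000442, the effective std used for out_proj
```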
src/transformers/models/vaultgemma/modeling_vaultgemma.py (7 additions, 0 deletions)
@@ -342,6 +342,13 @@ class VaultGemmaPreTrainedModel(PreTrainedModel):
"attentions": VaultGemmaAttention,
}

def _init_weights(self, module):
super()._init_weights(module)

# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()


@auto_docstring
class VaultGemmaModel(VaultGemmaPreTrainedModel):
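A note on the check used in most of the hunks above: matching on the class name (`"RMSNorm" in module.__class__.__name__`) catches every per-model RMSNorm subclass without importing each one. A minimal illustration with hypothetical classes:

```python
class DemoRMSNorm:      # stands in for GemmaRMSNorm, VaultGemmaRMSNorm, ...
    pass


class DemoLayerNorm:    # not matched, left to the generic initialization
    pass


for module in (DemoRMSNorm(), DemoLayerNorm()):
    print(module.__class__.__name__, "RMSNorm" in module.__class__.__name__)
# DemoRMSNorm True
# DemoLayerNorm False
```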