
Commit b16688e

General weight initialization scheme (#39579)
* general + modulars from llama
* all modular models
* style and fix musicgen
* fix
* Update configuration_musicgen.py
* Update modeling_utils.py
1 parent 015b62b commit b16688e

118 files changed: +205 additions, −1566 deletions


src/transformers/modeling_utils.py

Lines changed: 34 additions & 5 deletions
```diff
@@ -2967,12 +2967,41 @@ def disable_input_require_grads(self):
 
     def _init_weights(self, module):
         """
-        Initialize the weights. This method should be overridden by derived class and is
-        the only initialization method that will be called when loading a checkpoint
-        using `from_pretrained`. Any attempt to initialize outside of this function
-        will be useless as the torch.nn.init function are all replaced with skip.
+        Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
+        initialization schemes, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
+        `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
         """
-        pass
+        if hasattr(self.config, "initializer_range"):
+            std = self.config.initializer_range
+        else:
+            # 0.02 is the standard default value across the library
+            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+
+        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.MultiheadAttention):
+            # This uses torch's original init
+            module._reset_parameters()
+        # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
+        # between modelings (because they are prefixed with the model name)
+        elif (
+            isinstance(
+                module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
+            )
+            or "LayerNorm" in module.__class__.__name__
+            or "RMSNorm" in module.__class__.__name__
+        ):
+            # Norms can exist without weights (in which case they are None from torch primitives)
+            if hasattr(module, "weight") and module.weight is not None:
+                module.weight.data.fill_(1.0)
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()
 
     def _initialize_weights(self, module):
         """
```

src/transformers/models/aimv2/modeling_aimv2.py

Lines changed: 3 additions & 15 deletions
```diff
@@ -448,24 +448,12 @@ class Aimv2PreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
 
     def _init_weights(self, module):
-        std = (
-            self.config.vision_config.initializer_range
-            if hasattr(self.config, "vision_config")
-            else self.config.initializer_range
-        )
-        if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, Aimv2RMSNorm):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-        elif hasattr(module, "logit_scale"):
+        super()._init_weights(module)
+        if hasattr(module, "logit_scale"):
             if isinstance(module.logit_scale, nn.Parameter):
                 module.logit_scale.data.fill_(math.log(1 / 0.07))
         elif isinstance(module, Aimv2AttentionPoolingHead):
-            module.cls_token.data.normal_(mean=0.0, std=std)
+            module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 @auto_docstring(
```
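The pattern repeated across the touched models is visible here: defer to the base class for common primitives, then handle only the parameters the base cannot know about. A toy sketch of that shape (the class and helper below are hypothetical, not the real Aimv2 modules):

```python
import math
import torch
import torch.nn as nn

class ToyPoolingHead(nn.Module):  # hypothetical stand-in for Aimv2AttentionPoolingHead
    def __init__(self, dim):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.logit_scale = nn.Parameter(torch.zeros(1))

def init_extras(module, std=0.02):
    # in the real method this runs after super()._init_weights(module)
    if hasattr(module, "logit_scale") and isinstance(module.logit_scale, nn.Parameter):
        module.logit_scale.data.fill_(math.log(1 / 0.07))  # CLIP-style temperature
    if hasattr(module, "cls_token"):
        module.cls_token.data.normal_(mean=0.0, std=std)

ToyPoolingHead(8).apply(init_extras)
```

The `math.log(1 / 0.07)` value mirrors the temperature initialization popularized by CLIP.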

src/transformers/models/aimv2/modular_aimv2.py

Lines changed: 3 additions & 15 deletions
```diff
@@ -445,24 +445,12 @@ class Aimv2PreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
 
     def _init_weights(self, module):
-        std = (
-            self.config.vision_config.initializer_range
-            if hasattr(self.config, "vision_config")
-            else self.config.initializer_range
-        )
-        if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, Aimv2RMSNorm):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-        elif hasattr(module, "logit_scale"):
+        super()._init_weights(module)
+        if hasattr(module, "logit_scale"):
             if isinstance(module.logit_scale, nn.Parameter):
                 module.logit_scale.data.fill_(math.log(1 / 0.07))
         elif isinstance(module, Aimv2AttentionPoolingHead):
-            module.cls_token.data.normal_(mean=0.0, std=std)
+            module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 @auto_docstring(
```

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 0 additions & 13 deletions
```diff
@@ -324,19 +324,6 @@ class ArceePreTrainedModel(PreTrainedModel):
         "attentions": ArceeAttention,
     }
 
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, ArceeRMSNorm):
-            module.weight.data.fill_(1.0)
-
 
 @auto_docstring
 class ArceeModel(ArceePreTrainedModel):
```
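Arcee can delete its override entirely because everything it initialized is now covered by the base method, including the custom norm: the base matches norms by class name rather than by `isinstance`. A quick illustrative check (the class below is a stand-in, not an import from the library):

```python
import torch
import torch.nn as nn

class ArceeRMSNorm(nn.Module):  # stand-in with the same naming shape as the real class
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim))

module = ArceeRMSNorm(4)
# the base method's fallback: any class whose name contains "RMSNorm" is treated as a norm
assert "RMSNorm" in module.__class__.__name__
```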

src/transformers/models/aria/modeling_aria.py

Lines changed: 6 additions & 27 deletions
```diff
@@ -639,19 +639,9 @@ class AriaTextPreTrainedModel(PreTrainedModel):
     }
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, AriaTextRMSNorm):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, AriaGroupedExpertsGemm):
-            module.weight.data.normal_(mean=0.0, std=std)
+        super()._init_weights(module)
+        if isinstance(module, AriaGroupedExpertsGemm):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 @auto_docstring
@@ -672,20 +662,9 @@ class AriaPreTrainedModel(PreTrainedModel):
     }
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.MultiheadAttention):
-            # This uses torch's original init
-            module._reset_parameters()
-        elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
-        elif isinstance(module, AriaProjector):
-            nn.init.trunc_normal_(module.query, std=std)
+        super()._init_weights(module)
+        if isinstance(module, AriaProjector):
+            nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
 
 
 class AriaTextRotaryEmbedding(nn.Module):
```
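The only Aria-specific piece that survives is the truncated-normal init of the projector's learned query. A minimal sketch of that call (the shape and std here are illustrative, not taken from the config):

```python
import torch
import torch.nn as nn

query = nn.Parameter(torch.empty(32, 64))  # toy stand-in for AriaProjector.query
nn.init.trunc_normal_(query, std=0.02)  # samples outside the default [-2.0, 2.0] cutoffs are redrawn
```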

src/transformers/models/aria/modular_aria.py

Lines changed: 6 additions & 27 deletions
```diff
@@ -1294,19 +1294,9 @@ class AriaTextPreTrainedModel(PreTrainedModel):
     }
 
    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, AriaTextRMSNorm):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, AriaGroupedExpertsGemm):
-            module.weight.data.normal_(mean=0.0, std=std)
+        super()._init_weights(module)
+        if isinstance(module, AriaGroupedExpertsGemm):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
 
 class AriaPreTrainedModel(LlamaPreTrainedModel):
@@ -1316,20 +1306,9 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
     _supports_attention_backend = True
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.MultiheadAttention):
-            # This uses torch's original init
-            module._reset_parameters()
-        elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
-        elif isinstance(module, AriaProjector):
-            nn.init.trunc_normal_(module.query, std=std)
+        LlamaPreTrainedModel._init_weights(self, module)
+        if isinstance(module, AriaProjector):
+            nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
 
 
 class AriaTextModel(LlamaModel):
```

src/transformers/models/aya_vision/modeling_aya_vision.py

Lines changed: 0 additions & 15 deletions
```diff
@@ -100,21 +100,6 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
     _supports_flex_attn = True
     _supports_attention_backend = True
 
-    def _init_weights(self, module):
-        std = (
-            self.config.initializer_range
-            if hasattr(self.config, "initializer_range")
-            else self.config.text_config.initializer_range
-        )
-
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
-
 
 @dataclass
 @auto_docstring(
```
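AyaVision's old override existed mainly to fall back to the text config's `initializer_range` on composite configs; the new base method covers that through `config.get_text_config()`. A toy illustration of the lookup order (simple namespace objects, not real transformers configs):

```python
from types import SimpleNamespace

class ToyComposite:  # stand-in for a multimodal config without its own initializer_range
    text_config = SimpleNamespace(initializer_range=0.02)
    def get_text_config(self):
        return self.text_config

config = ToyComposite()
std = (config.initializer_range if hasattr(config, "initializer_range")
       else getattr(config.get_text_config(), "initializer_range", 0.02))
assert std == 0.02
```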

src/transformers/models/aya_vision/modular_aya_vision.py

Lines changed: 0 additions & 15 deletions
```diff
@@ -92,21 +92,6 @@ def pixel_shuffle(self, image_features):  # B, S, D
 class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
     _supports_static_cache = False
 
-    def _init_weights(self, module):
-        std = (
-            self.config.initializer_range
-            if hasattr(self.config, "initializer_range")
-            else self.config.text_config.initializer_range
-        )
-
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
-
 
 class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
     pass
```

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 2 additions & 12 deletions
```diff
@@ -1088,18 +1088,8 @@ class BambaPreTrainedModel(PreTrainedModel):
     _is_stateful = True
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, (nn.Linear, nn.Conv1d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, BambaMixer):
+        super()._init_weights(module)
+        if isinstance(module, BambaMixer):
             module.dt_bias.data.fill_(1.0)
             module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
             module.D.data.fill_(1.0)
```

src/transformers/models/bamba/modular_bamba.py

Lines changed: 2 additions & 12 deletions
```diff
@@ -816,18 +816,8 @@ class BambaPreTrainedModel(PreTrainedModel):
     _is_stateful = True
 
     def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, (nn.Linear, nn.Conv1d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-        elif isinstance(module, BambaMixer):
+        super()._init_weights(module)
+        if isinstance(module, BambaMixer):
             module.dt_bias.data.fill_(1.0)
             module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
             module.D.data.fill_(1.0)
```
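What stays behind in Bamba's override is the Mamba-style state-space init that the generic scheme cannot supply. A standalone sketch of those three assignments (`num_heads` here is illustrative; the real value comes from the config):

```python
import torch

num_heads = 8
dt_bias = torch.ones(num_heads)                    # dt bias starts at 1
A_log = torch.log(torch.arange(1, num_heads + 1))  # A kept in log space: exp(A_log) = 1..num_heads
D = torch.ones(num_heads)                          # skip-connection scale starts at 1
assert torch.allclose(A_log.exp(), torch.arange(1.0, num_heads + 1))
```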
