@@ -2967,12 +2967,41 @@ def disable_input_require_grads(self):
    def _init_weights(self, module):
"""
-        Initialize the weights. This method should be overridden by derived class and is
-        the only initialization method that will be called when loading a checkpoint
-        using `from_pretrained`. Any attempt to initialize outside of this function
-        will be useless as the torch.nn.init function are all replaced with skip.
+        Initialize the weights. This is deliberately generic, in the spirit of what the library usually does. For more
+        complex initialization schemes, it should be overridden by the derived `PreTrainedModel` class. If a model adds
+        an explicit `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
        """
-        pass
+        if hasattr(self.config, "initializer_range"):
+            std = self.config.initializer_range
+        else:
+            # 0.02 is the standard default value across the library
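+            # get_text_config() returns the nested text config for composite models, and the config itself otherwise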
+            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+
+        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.MultiheadAttention):
+            # This uses torch's original init
+            module._reset_parameters()
+        # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they are usually custom modules whose names
+        # vary between modeling files (they are prefixed with the model name)
+        elif (
+            isinstance(
+                module, (nn.LayerNorm, nn.RMSNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
+            )
+            or "LayerNorm" in module.__class__.__name__
+            or "RMSNorm" in module.__class__.__name__
+        ):
+            # Norms can exist without weights (in which case they are None from torch primitives)
+            if hasattr(module, "weight") and module.weight is not None:
+                module.weight.data.fill_(1.0)
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()

    def _initialize_weights(self, module):
"""