@@ -83,6 +83,42 @@ def forward(self, x):
         return down_proj


+class Gemma2RotaryEmbedding(nn.Module):
+    inv_freq: torch.Tensor  # fix linting for `register_buffer`
+
+    def __init__(self, config: Gemma2Config, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
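For reference, the relocated `Gemma2RotaryEmbedding` pairs with the `rotate_half` helper shown in the context lines above: `forward` returns per-position `cos`/`sin` tensors that the attention layers mix into queries and keys. Below is a minimal standalone sketch of that interaction; the shapes, rope base, and variable names are illustrative assumptions, not taken from this diff.

```python
import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input (same helper as in the diff context above)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Illustrative shapes only: batch=1, seq_len=4, head_dim=8, rope base 10000.
head_dim, seq_len = 8, 4
position_ids = torch.arange(seq_len)[None, :]  # (1, seq_len)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)

# Same math as Gemma2RotaryEmbedding.forward: outer product of positions and inv_freq,
# duplicated along the last dim so cos/sin span the full head_dim.
freqs = position_ids.float()[..., None] * inv_freq  # (1, seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)             # (1, seq_len, head_dim)
cos, sin = emb.cos(), emb.sin()

# How a single query/key head consumes cos/sin (head dimension omitted for brevity):
q = torch.randn(1, seq_len, head_dim)
q_rotated = q * cos + rotate_half(q) * sin
print(q_rotated.shape)  # torch.Size([1, 4, 8])
```

In the full file, this combination is performed by `apply_rotary_pos_emb`, defined just after `rotate_half`.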
@@ -299,42 +335,6 @@ def forward(
         return outputs


-class Gemma2RotaryEmbedding(nn.Module):
-    inv_freq: torch.Tensor  # fix linting for `register_buffer`
-
-    def __init__(self, config: Gemma2Config, device=None):
-        super().__init__()
-        # BC: "rope_type" was originally "type"
-        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
-            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            self.rope_type = "default"
-        self.max_seq_len_cached = config.max_position_embeddings
-        self.original_max_seq_len = config.max_position_embeddings
-
-        self.config = config
-        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
-
-        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = self.inv_freq
-
-    @torch.no_grad()
-    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
-    def forward(self, x, position_ids):
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
-        position_ids_expanded = position_ids[:, None, :].float()
-
-        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
 @auto_docstring
 class Gemma2PreTrainedModel(PreTrainedModel):
     config: Gemma2Config
@@ -353,6 +353,13 @@ class Gemma2PreTrainedModel(PreTrainedModel):
         "attentions": Gemma2Attention,
     }

+    def _init_weights(self, module):
+        super()._init_weights(module)
+
+        # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
+        if "RMSNorm" in module.__class__.__name__:
+            module.weight.data.zero_()
+

 @auto_docstring
 class Gemma2Model(Gemma2PreTrainedModel):
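The new `_init_weights` override zeroes RMSNorm weights because Gemma-style RMSNorm scales the normalized activations by `(1 + weight)` rather than `weight`, so a zero weight is an identity scale at initialization. Below is a hedged sketch of that convention; the class name, `eps` value, and shapes are illustrative, and the actual norm class is defined elsewhere in this file.

```python
import torch
from torch import nn


class GemmaStyleRMSNorm(nn.Module):
    """Illustrative re-implementation of the (1 + weight) RMSNorm convention."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))  # zero-init, matching _init_weights above

    def forward(self, x):
        # Normalize in float32, then scale by (1 + weight): a zero weight means identity scaling.
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return (norm * (1.0 + self.weight.float())).type_as(x)


# With the weight zero-initialized, the module starts out as a plain RMS normalization:
x = torch.randn(2, 5, 16)
layer = GemmaStyleRMSNorm(16)
print(torch.allclose(layer(x), x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)))  # True
```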