@@ -22,7 +22,8 @@
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
@@ -185,36 +171,38 @@ def __init__(
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
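
The rewritten forward assumes the shared RMSNorm accepts an optional residual: called as `norm(x)` it only normalizes, while `norm(x, residual)` first adds the pending skip connection and returns both the normalized activations and the updated residual. A rough pure-PyTorch stand-in for that calling convention (class name is illustrative; the real vLLM layer would typically dispatch to a fused add+norm kernel rather than these eager ops):

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class ReferenceRMSNorm(nn.Module):
    """Pure-PyTorch stand-in mirroring the interface the decoder layer expects."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Fold the pending skip-connection add into the norm and hand the
            # summed tensor back as the new residual for the next add.
            x = x + residual
            residual = x
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        out = out * self.weight
        return out if residual is None else (out, residual)

With this contract the dropped `hidden_states = residual + hidden_states` lines are not lost: each addition simply happens inside the next norm call (or inside the model's final norm), so the add and the normalization can be fused into a single kernel.
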
@@ -235,7 +223,7 @@ def __init__(
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
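
At the model level the same contract is threaded through the layer stack: `residual` starts as None, each layer hands back its un-added output together with the running residual, and the final `self.norm(hidden_states, residual)` performs the last pending addition. A small self-contained check (illustrative helper names, with a single norm plus a Linear stub standing in for each real attention/MLP layer) that this deferred-add formulation matches the old eager adds:

import torch

torch.manual_seed(0)


def rms_norm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight


def add_rms_norm(x, residual, weight, eps=1e-6):
    # Deferred add: fold the pending skip-connection sum into the norm and
    # return the summed tensor as the new residual.
    summed = x + residual
    return rms_norm(summed, weight, eps), summed


dim, num_layers = 16, 3
weight = torch.ones(dim)
blocks = [torch.nn.Linear(dim, dim) for _ in range(num_layers)]  # stand-ins for decoder blocks
x = torch.randn(2, dim)

# Old pattern: each layer adds its residual eagerly before returning.
h = x
for block in blocks:
    h = h + block(rms_norm(h, weight))
out_eager = rms_norm(h, weight)

# New pattern: the add is deferred into the next norm call, and the final
# norm performs the last pending addition, as in GemmaModel.forward above.
h, residual = x, None
for block in blocks:
    if residual is None:
        residual = h
        normed = rms_norm(h, weight)
    else:
        normed, residual = add_rms_norm(h, residual, weight)
    h = block(normed)
out_deferred, _ = add_rms_norm(h, residual, weight)

assert torch.allclose(out_eager, out_deferred, atol=1e-6)
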
@@ -321,6 +311,10 @@ def load_weights(self,
                 # Skip loading extra layer for lora models.
                 if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -329,5 +323,5 @@ def load_weights(self,
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")