Commit e8f9cc5

WoosukKwon authored and jimpang committed
Use Llama RMSNorm custom op for Gemma (vllm-project#2974)
1 parent 031fd41 commit e8f9cc5

File tree

1 file changed: +27 -33 lines changed

vllm/model_executor/models/gemma.py

Lines changed: 27 additions & 33 deletions
@@ -22,6 +22,7 @@
 
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearMethodBase,
                                                QKVParallelLinear,
@@ -40,21 +41,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-class GemmaRMSNorm(nn.Module):
-
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-        self.weight = nn.Parameter(torch.zeros(dim))
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * (1 + self.weight)
-
-
 class GemmaMLP(nn.Module):
 
     def __init__(
@@ -185,36 +171,38 @@ def __init__(
             intermediate_size=config.intermediate_size,
             linear_method=linear_method,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
 
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        return hidden_states
+        return hidden_states, residual
 
 
 class GemmaModel(nn.Module):
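
The rewritten forward threads a residual tensor through each norm call: vLLM's RMSNorm layer accepts an optional residual and, when given one, returns both the normalized activations and the updated residual, which is why the explicit `hidden_states = residual + hidden_states` adds disappear (and, per the commit title, why the custom op can take over this work). Below is a minimal eager-PyTorch sketch of that interface; the class name and implementation are illustrative only and are not the real vLLM op, which presumably fuses the add with the normalization in a CUDA kernel.

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Illustrative stand-in for the RMSNorm interface used in the diff."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # The pending residual add is folded into the norm call; the
            # summed tensor is handed back as the next layer's residual.
            x = x + residual
            residual = x
        variance = x.float().pow(2).mean(-1, keepdim=True)
        normed = (x.float() * torch.rsqrt(variance + self.variance_epsilon))
        out = normed.type_as(x) * self.weight
        return out if residual is None else (out, residual)

With this interface, GemmaModel.forward can start the first layer with residual=None and pass each layer's returned residual into the next, so the final self.norm(hidden_states, residual) applies the last pending add as part of the same call.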
@@ -235,7 +223,7 @@ def __init__(
             GemmaDecoderLayer(config, linear_method)
             for _ in range(config.num_hidden_layers)
         ])
-        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -246,17 +234,19 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         # Normalize the embedding by sqrt(hidden_size)
-        hidden_states = hidden_states * (self.config.hidden_size**0.5)
+        hidden_states *= self.config.hidden_size**0.5
 
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -321,6 +311,10 @@ def load_weights(self,
                 # Skip loading extra layer for lora models.
                 if "lm_head" in name:
                     continue
+                # GemmaRMSNorm is different from Llama's in that it multiplies
+                # (1 + weight) to the output, instead of just weight.
+                if "norm.weight" in name:
+                    loaded_weight += 1.0
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
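
The added comment is the crux of the change: a Gemma checkpoint stores a norm weight w that the deleted GemmaRMSNorm applied as normalize(x) * (1 + w), while the Llama-style RMSNorm applies normalize(x) * weight, so folding +1.0 into the parameter at load time makes the two coincide. A small standalone check of that equivalence follows; the class and variable names are illustrative, not vLLM's.

import torch
import torch.nn as nn


class LlamaStyleRMSNorm(nn.Module):
    """out = normalize(x) * weight, i.e. the convention RMSNorm expects."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * self.weight


torch.manual_seed(0)
dim = 8
x = torch.randn(4, dim)
checkpoint_weight = torch.randn(dim) * 0.1  # weight as stored by a Gemma checkpoint

# Reference: the deleted GemmaRMSNorm computes normalize(x) * (1 + weight).
reference = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
reference = reference * (1 + checkpoint_weight)

# Fold +1 into the parameter at load time, as load_weights now does.
norm = LlamaStyleRMSNorm(dim)
with torch.no_grad():
    norm.weight.copy_(checkpoint_weight + 1.0)

assert torch.allclose(norm(x), reference)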
@@ -329,5 +323,5 @@ def load_weights(self,
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
             raise RuntimeError(
-                f"Some weights are not initialized from checkpoints: {unloaded_params}"
-            )
+                "Some weights are not initialized from checkpoints: "
+                f"{unloaded_params}")

0 commit comments
