
Commit 5253eda

Add Gemma model (#2964)
1 parent 017d9f1 commit 5253eda

2 files changed, +334 −0 lines

vllm/model_executor/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
+    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
     "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
vllm/model_executor/models/gemma.py (new file)

Lines changed: 333 additions & 0 deletions
@@ -0,0 +1,333 @@
# coding=utf-8
# Copyright 2023 The vLLM team.
# Copyright (c) Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import GemmaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class GemmaRMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * (1 + self.weight)


class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_proj = ColumnParallelLinear(hidden_size,
                                              intermediate_size,
                                              bias=False,
                                              linear_method=linear_method)
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=False,
                                            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = nn.GELU()

    def forward(self, x):
        gate, _ = self.gate_proj(x)
        gate = self.act_fn(gate)
        up, _ = self.up_proj(x)
        fuse = gate * up
        outputs, _ = self.down_proj(fuse)
        return outputs


class GemmaAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 head_dim: int,
                 max_position_embeddings: int = 8192,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GemmaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            linear_method=linear_method,
        )
        self.mlp = GemmaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method,
        )
        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class GemmaModel(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            GemmaDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Normalize the embedding by sqrt(hidden_size)
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: GemmaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = GemmaModel(config, linear_method)
        self.sampler = Sampler(config.vocab_size)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.embed_tokens.weight,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            for (param_name, shard_name, shard_id) in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra layer for lora models.
                if "lm_head" in name:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                f"Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
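
Two Gemma-specific details in this file are easy to miss when skimming the diff: GemmaRMSNorm keeps a zero-initialized weight and scales by (1 + weight), so loaded norm weights act as an offset from 1, and GemmaModel multiplies the token embeddings by sqrt(hidden_size) before the first decoder layer. A standalone plain-PyTorch sketch of both, using illustrative tensors rather than vLLM code:

    import torch

    hidden_size = 8
    x = torch.randn(2, hidden_size)

    # GemmaRMSNorm: with the weight at its zero initialization, the
    # (1 + weight) scale reduces to a plain RMS normalization.
    weight = torch.zeros(hidden_size)
    normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    out = normed * (1 + weight)

    # GemmaModel.forward: token embeddings are scaled by sqrt(hidden_size)
    # before entering the first decoder layer.
    embedding = torch.randn(2, hidden_size)
    scaled = embedding * (hidden_size ** 0.5)

    print(out.shape, scaled.shape)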
