Commit 179a6a3

[Model]Refactor MiniCPMV (vllm-project#7020)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Parent: 83c644f

File tree: 4 files changed, +937 −386 lines


docs/source/models/supported_models.rst (+1 −1)

@@ -220,7 +220,7 @@ Vision Language Models
     - Phi-3-Vision
     - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
     -
-  * - :code:`MiniCPM-V`
+  * - :code:`MiniCPMV`
     - MiniCPM-V
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
     -
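The first column of this table is the architecture name, which vLLM matches against the `architectures` field of a checkpoint's Hugging Face config; the rename aligns the docs entry with the `MiniCPMV` string. As a hedged illustration (assuming network access to the Hub and that the checkpoint's config declares this architecture), the field can be inspected like so:

from transformers import AutoConfig

# Inspect the `architectures` field that the "Architecture" column refers to;
# the checkpoint name is taken from the table row above.
cfg = AutoConfig.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5",
                                 trust_remote_code=True)
print(cfg.architectures)  # expected to include "MiniCPMV"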
New file: Idefics2 vision model (+296 lines)

@@ -0,0 +1,296 @@
# coding=utf-8

# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""

from typing import Optional

import torch
from torch import nn
from transformers.models.idefics2.configuration_idefics2 import (
    Idefics2Config, Idefics2VisionConfig)
from xformers import ops as xops

from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig

class Idefics2VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings`
    to enable images of variable resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision
    Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the
    need to resize them to the same fixed size. In particular, we start from the
    original pre-trained SigLIP model (which uses fixed-size square images) and
    adapt it by training on images of variable resolutions.
    """

    def __init__(self, config: Idefics2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = nn.Embedding(self.num_positions,
                                               self.embed_dim)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        patch_attention_mask: torch.BoolTensor,
    ) -> torch.Tensor:
        batch_size, _, max_im_h, max_im_w = pixel_values.shape
        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)
        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0,
                                  1 / self.num_patches_per_side)
        position_ids = torch.full(size=(batch_size,
                                        max_nb_patches_h * max_nb_patches_w),
                                  fill_value=0)

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()
            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
            bucket_coords_h = torch.bucketize(fractional_coords_h,
                                              boundaries,
                                              right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w,
                                              boundaries,
                                              right=True)
            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side +
                       bucket_coords_w).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings

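To make the variable-resolution position-id assignment above concrete, here is a standalone sketch (not part of the new file) that mirrors the bucketize logic in `forward` for one hypothetical image whose valid region covers 6 x 4 patches on a 32 x 32-patch position grid:

import torch

num_patches_per_side = 32          # position grid of the pre-trained SigLIP model
nb_patches_h, nb_patches_w = 6, 4  # valid patches of this (hypothetical) image

boundaries = torch.arange(1 / num_patches_per_side, 1.0,
                          1 / num_patches_per_side)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Each fractional coordinate is bucketed into the cell of the fixed-size
# position grid it falls into, so a small image reuses a spread-out subset
# of the pre-trained position embeddings.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side +
           bucket_coords_w).flatten()

print(pos_ids.shape)  # torch.Size([24]) -- one position id per valid patch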
class Idefics2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"  # noqa: E501
                f" {self.num_heads}).")
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, q_len, _ = hidden_states.size()
        qkv, _ = self.qkv_proj(
            hidden_states
        )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
        query_states, key_states, value_states = qkv.chunk(3, dim=-1)
        query_states = query_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        key_states = key_states.view(batch_size, q_len,
                                     self.num_heads_per_partition,
                                     self.head_dim)
        value_states = value_states.view(batch_size, q_len,
                                         self.num_heads_per_partition,
                                         self.head_dim)
        # see: https://facebookresearch.github.io/xformers/components/ops.html
        out = xops.memory_efficient_attention_forward(
            query_states,
            key_states,
            value_states,
            p=self.dropout,
            scale=self.scale,
        )
        out = out.view(batch_size, q_len, -1)
        attn_output, _ = self.out_proj(out)
        return attn_output

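For reference, `xops.memory_efficient_attention_forward` above consumes query/key/value laid out as (batch, seq_len, heads, head_dim) and returns the same layout. Below is a minimal equivalence sketch (not part of the diff) using plain PyTorch `scaled_dot_product_attention`, which instead expects (batch, heads, seq_len, head_dim); the shapes are hypothetical:

import torch
import torch.nn.functional as F

batch_size, q_len, num_heads, head_dim = 2, 16, 4, 32
q = torch.randn(batch_size, q_len, num_heads, head_dim)
k = torch.randn(batch_size, q_len, num_heads, head_dim)
v = torch.randn(batch_size, q_len, num_heads, head_dim)

# SDPA wants (batch, heads, seq, head_dim), so transpose in and out.
# Its default softmax scale is 1/sqrt(head_dim), matching `self.scale` above.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
).transpose(1, 2)

print(out.shape)  # torch.Size([2, 16, 4, 32])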
class Idefics2VisionMLP(nn.Module):

    def __init__(
        self,
        config: Idefics2Config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states

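The fc1/fc2 split above follows the usual tensor-parallel MLP pattern: fc1 is column-parallel (each rank holds a slice of the intermediate features) and fc2 is row-parallel (partial outputs are summed across ranks). A standalone sketch (not part of the diff, using two simulated "ranks", plain matmuls, and relu in place of the configured activation, which is elementwise either way) showing why this reproduces the unsharded MLP:

import torch

torch.manual_seed(0)
hidden, inter = 8, 16
x = torch.randn(3, hidden)
w1 = torch.randn(inter, hidden)   # fc1 weight, output features split by column-parallel
w2 = torch.randn(hidden, inter)   # fc2 weight, input features split by row-parallel

full = torch.relu(x @ w1.t()) @ w2.t()

# Rank 0 holds the first half of the intermediate features, rank 1 the rest;
# summing the partial fc2 outputs is the all-reduce RowParallelLinear performs.
parts = []
for sl in (slice(0, inter // 2), slice(inter // 2, inter)):
    parts.append(torch.relu(x @ w1[sl].t()) @ w2[:, sl].t())
sharded = parts[0] + parts[1]

print(torch.allclose(full, sharded, atol=1e-4))  # True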
class Idefics2EncoderLayer(nn.Module):

    def __init__(self, config: Idefics2Config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = Idefics2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`):
                Input to the layer of shape `(batch, seq_len, embed_dim)`.
        """
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

class Idefics2Encoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self-attention
    layers. Each layer is a [`Idefics2EncoderLayer`].

    Args:
        config: Idefics2Config
    """

    def __init__(self, config: Idefics2Config):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            Idefics2EncoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Args:
            inputs_embeds (torch.Tensor):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation. This is useful if you
                want more control over how to convert `input_ids` indices into
                associated vectors than the model's internal embedding lookup
                matrix.
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(hidden_states)
            hidden_states = layer_outputs
        return hidden_states

class Idefics2VisionTransformer(nn.Module):

    def __init__(self, config: Idefics2VisionConfig):
        super().__init__()
        embed_dim = config.hidden_size
        self.config = config
        self.embeddings = Idefics2VisionEmbeddings(config)
        self.encoder = Idefics2Encoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask)
        encoder_outputs = self.encoder(hidden_states)
        last_hidden_state = self.post_layernorm(encoder_outputs)
        return last_hidden_state
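As a usage note (not part of the diff), the vision tower expects a batch of images padded to a common canvas plus a boolean `patch_attention_mask` marking which patch positions are real. A minimal sketch of constructing such inputs, with hypothetical sizes:

import torch

patch_size = 14
canvas = patch_size * 32                  # padded canvas: 32 x 32 patches per side
pixel_values = torch.zeros(2, 3, canvas, canvas)

# Image 0 fills the whole canvas (32 x 32 patches); image 1 only 20 x 24.
patch_attention_mask = torch.zeros(2, 32, 32, dtype=torch.bool)
patch_attention_mask[0] = True
patch_attention_mask[1, :20, :24] = True

# With a config whose image_size // patch_size equals 32, the tower would be
# called roughly as:
#   vision_tower = Idefics2VisionTransformer(config)
#   last_hidden_state = vision_tower(pixel_values, patch_attention_mask)
# (left as a comment here, since instantiating the vLLM parallel linear layers
# also requires the tensor-parallel process group to be initialized).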
