|
| 1 | +# coding=utf-8 |
| 2 | + |
| 3 | +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py |
| 4 | +# Copyright 2024 The vLLM team. |
| 5 | +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. |
| 6 | +# |
| 7 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 8 | +# you may not use this file except in compliance with the License. |
| 9 | +# You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, software |
| 14 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 15 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 16 | +# See the License for the specific language governing permissions and |
| 17 | +# limitations under the License. |
| 18 | +"""PyTorch Idefics2 model.""" |
| 19 | + |
| 20 | +from typing import Optional |
| 21 | + |
| 22 | +import torch |
| 23 | +from torch import nn |
| 24 | +from transformers.models.idefics2.configuration_idefics2 import ( |
| 25 | + Idefics2Config, Idefics2VisionConfig) |
| 26 | +from xformers import ops as xops |
| 27 | + |
| 28 | +from vllm.distributed import divide, get_tensor_model_parallel_world_size |
| 29 | +from vllm.model_executor.layers.activation import get_act_fn |
| 30 | +from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
| 31 | + QKVParallelLinear, |
| 32 | + RowParallelLinear) |
| 33 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
| 34 | + |
| 35 | + |
| 36 | +class Idefics2VisionEmbeddings(nn.Module): |
| 37 | + """ |
| 38 | + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings |
| 39 | + ` to enable images of variable |
| 40 | + resolution. |
| 41 | +
|
| 42 | + The modifications are adapted from [Patch n' Pack: NaViT, a Vision |
| 43 | + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) |
| 44 | + which allows treating images in their native aspect ratio and without the |
| 45 | + need to resize them to the same fixed size. In particular, we start from the |
| 46 | + original pre-trained SigLIP model(which uses images of fixed-size square |
| 47 | + images) and adapt it by training on images of variable resolutions. |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self, config: Idefics2VisionConfig): |
| 51 | + super().__init__() |
| 52 | + self.embed_dim = config.hidden_size |
| 53 | + self.image_size = config.image_size |
| 54 | + self.patch_size = config.patch_size |
| 55 | + self.patch_embedding = nn.Conv2d( |
| 56 | + in_channels=config.num_channels, |
| 57 | + out_channels=self.embed_dim, |
| 58 | + kernel_size=self.patch_size, |
| 59 | + stride=self.patch_size, |
| 60 | + padding="valid", |
| 61 | + ) |
| 62 | + self.num_patches_per_side = self.image_size // self.patch_size |
| 63 | + self.num_patches = self.num_patches_per_side**2 |
| 64 | + self.num_positions = self.num_patches |
| 65 | + self.position_embedding = nn.Embedding(self.num_positions, |
| 66 | + self.embed_dim) |
| 67 | + |
| 68 | + def forward( |
| 69 | + self, |
| 70 | + pixel_values: torch.FloatTensor, |
| 71 | + patch_attention_mask: torch.BoolTensor, |
| 72 | + ) -> torch.Tensor: |
| 73 | + batch_size, _, max_im_h, max_im_w = pixel_values.shape |
| 74 | + patch_embeds = self.patch_embedding(pixel_values) |
| 75 | + embeddings = patch_embeds.flatten(2).transpose(1, 2) |
| 76 | + max_nb_patches_h, max_nb_patches_w = ( |
| 77 | + max_im_h // self.patch_size, |
| 78 | + max_im_w // self.patch_size, |
| 79 | + ) |
| 80 | + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, |
| 81 | + 1 / self.num_patches_per_side) |
| 82 | + position_ids = torch.full(size=(batch_size, |
| 83 | + max_nb_patches_h * max_nb_patches_w), |
| 84 | + fill_value=0) |
| 85 | + |
| 86 | + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): |
| 87 | + nb_patches_h = p_attn_mask[:, 0].sum() |
| 88 | + nb_patches_w = p_attn_mask[0].sum() |
| 89 | + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) |
| 90 | + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) |
| 91 | + bucket_coords_h = torch.bucketize(fractional_coords_h, |
| 92 | + boundaries, |
| 93 | + right=True) |
| 94 | + bucket_coords_w = torch.bucketize(fractional_coords_w, |
| 95 | + boundaries, |
| 96 | + right=True) |
| 97 | + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + |
| 98 | + bucket_coords_w).flatten() |
| 99 | + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids |
| 100 | + position_ids = position_ids.to(self.position_embedding.weight.device) |
| 101 | + embeddings = embeddings + self.position_embedding(position_ids) |
| 102 | + return embeddings |
| 103 | + |
| 104 | + |
| 105 | +class Idefics2VisionAttention(nn.Module): |
| 106 | + """Multi-headed attention from 'Attention Is All You Need' paper""" |
| 107 | + |
| 108 | + def __init__( |
| 109 | + self, |
| 110 | + config: Idefics2Config, |
| 111 | + quant_config: Optional[QuantizationConfig] = None, |
| 112 | + ): |
| 113 | + super().__init__() |
| 114 | + self.config = config |
| 115 | + self.embed_dim = config.hidden_size |
| 116 | + self.num_heads = config.num_attention_heads |
| 117 | + self.head_dim = self.embed_dim // self.num_heads |
| 118 | + if self.head_dim * self.num_heads != self.embed_dim: |
| 119 | + raise ValueError( |
| 120 | + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 |
| 121 | + f" {self.num_heads}).") |
| 122 | + self.scale = self.head_dim**-0.5 |
| 123 | + self.dropout = config.attention_dropout |
| 124 | + self.qkv_proj = QKVParallelLinear( |
| 125 | + self.embed_dim, |
| 126 | + self.head_dim, |
| 127 | + self.num_heads, |
| 128 | + quant_config=quant_config, |
| 129 | + ) |
| 130 | + self.out_proj = RowParallelLinear( |
| 131 | + self.embed_dim, |
| 132 | + self.embed_dim, |
| 133 | + bias=True, |
| 134 | + quant_config=quant_config, |
| 135 | + ) |
| 136 | + self.tp_size = get_tensor_model_parallel_world_size() |
| 137 | + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) |
| 138 | + self.is_causal = False |
| 139 | + |
| 140 | + def forward( |
| 141 | + self, |
| 142 | + hidden_states: torch.Tensor, |
| 143 | + ) -> torch.Tensor: |
| 144 | + batch_size, q_len, _ = hidden_states.size() |
| 145 | + qkv, _ = self.qkv_proj( |
| 146 | + hidden_states |
| 147 | + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim |
| 148 | + query_states, key_states, value_states = qkv.chunk(3, dim=-1) |
| 149 | + query_states = query_states.view(batch_size, q_len, |
| 150 | + self.num_heads_per_partition, |
| 151 | + self.head_dim) |
| 152 | + key_states = key_states.view(batch_size, q_len, |
| 153 | + self.num_heads_per_partition, |
| 154 | + self.head_dim) |
| 155 | + value_states = value_states.view(batch_size, q_len, |
| 156 | + self.num_heads_per_partition, |
| 157 | + self.head_dim) |
| 158 | + # see: https://facebookresearch.github.io/xformers/components/ops.html |
| 159 | + out = xops.memory_efficient_attention_forward( |
| 160 | + query_states, |
| 161 | + key_states, |
| 162 | + value_states, |
| 163 | + p=self.dropout, |
| 164 | + scale=self.scale, |
| 165 | + ) |
| 166 | + out = out.view(batch_size, q_len, -1) |
| 167 | + attn_output, _ = self.out_proj(out) |
| 168 | + return attn_output |
| 169 | + |
| 170 | + |
| 171 | +class Idefics2VisionMLP(nn.Module): |
| 172 | + |
| 173 | + def __init__( |
| 174 | + self, |
| 175 | + config: Idefics2Config, |
| 176 | + quant_config: Optional[QuantizationConfig] = None, |
| 177 | + ): |
| 178 | + super().__init__() |
| 179 | + self.config = config |
| 180 | + self.activation_fn = get_act_fn(config.hidden_act) |
| 181 | + self.fc1 = ColumnParallelLinear( |
| 182 | + config.hidden_size, |
| 183 | + config.intermediate_size, |
| 184 | + bias=True, |
| 185 | + quant_config=quant_config, |
| 186 | + ) |
| 187 | + self.fc2 = RowParallelLinear( |
| 188 | + config.intermediate_size, |
| 189 | + config.hidden_size, |
| 190 | + bias=True, |
| 191 | + quant_config=quant_config, |
| 192 | + ) |
| 193 | + |
| 194 | + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
| 195 | + hidden_states, _ = self.fc1(hidden_states) |
| 196 | + hidden_states = self.activation_fn(hidden_states) |
| 197 | + hidden_states, _ = self.fc2(hidden_states) |
| 198 | + return hidden_states |
| 199 | + |
| 200 | + |
| 201 | +class Idefics2EncoderLayer(nn.Module): |
| 202 | + |
| 203 | + def __init__(self, config: Idefics2Config): |
| 204 | + super().__init__() |
| 205 | + self.embed_dim = config.hidden_size |
| 206 | + self.self_attn = Idefics2VisionAttention(config) |
| 207 | + self.layer_norm1 = nn.LayerNorm(self.embed_dim, |
| 208 | + eps=config.layer_norm_eps) |
| 209 | + self.mlp = Idefics2VisionMLP(config) |
| 210 | + self.layer_norm2 = nn.LayerNorm(self.embed_dim, |
| 211 | + eps=config.layer_norm_eps) |
| 212 | + |
| 213 | + def forward( |
| 214 | + self, |
| 215 | + hidden_states: torch.Tensor, |
| 216 | + ) -> torch.Tensor: |
| 217 | + """ |
| 218 | + Args: |
| 219 | + hidden_states (`torch.FloatTensor`): |
| 220 | + Input to the layer of shape `(batch, seq_len, embed_dim)`. |
| 221 | +
|
| 222 | + """ |
| 223 | + residual = hidden_states |
| 224 | + hidden_states = self.layer_norm1(hidden_states) |
| 225 | + hidden_states = self.self_attn(hidden_states) |
| 226 | + hidden_states = residual + hidden_states |
| 227 | + residual = hidden_states |
| 228 | + hidden_states = self.layer_norm2(hidden_states) |
| 229 | + hidden_states = self.mlp(hidden_states) |
| 230 | + hidden_states = residual + hidden_states |
| 231 | + return hidden_states |
| 232 | + |
| 233 | + |
| 234 | +class Idefics2Encoder(nn.Module): |
| 235 | + """ |
| 236 | + Transformer encoder consisting of `config.num_hidden_layers` self attention |
| 237 | + layers. Each layer is a |
| 238 | + [`Idefics2EncoderLayer`]. |
| 239 | +
|
| 240 | + Args: |
| 241 | + config: Idefics2Config |
| 242 | + """ |
| 243 | + |
| 244 | + def __init__(self, config: Idefics2Config): |
| 245 | + super().__init__() |
| 246 | + self.config = config |
| 247 | + self.layers = nn.ModuleList([ |
| 248 | + Idefics2EncoderLayer(config) |
| 249 | + for _ in range(config.num_hidden_layers) |
| 250 | + ]) |
| 251 | + |
| 252 | + def forward( |
| 253 | + self, |
| 254 | + inputs_embeds: torch.Tensor, |
| 255 | + ) -> torch.Tensor: |
| 256 | + r""" |
| 257 | + Args: |
| 258 | + inputs_embeds (torch.Tensor): |
| 259 | + Optionally, instead of passing `input_ids` you can choose to |
| 260 | + directly pass an embedded representation. |
| 261 | + This is useful if you want more control over how to convert |
| 262 | + `input_ids` indices into associated vectorsthan the model's |
| 263 | + internal embedding lookup matrix. |
| 264 | + """ |
| 265 | + hidden_states = inputs_embeds |
| 266 | + for encoder_layer in self.layers: |
| 267 | + layer_outputs = encoder_layer(hidden_states) |
| 268 | + hidden_states = layer_outputs |
| 269 | + return hidden_states |
| 270 | + |
| 271 | + |
| 272 | +class Idefics2VisionTransformer(nn.Module): |
| 273 | + |
| 274 | + def __init__(self, config: Idefics2VisionConfig): |
| 275 | + super().__init__() |
| 276 | + embed_dim = config.hidden_size |
| 277 | + self.config = config |
| 278 | + self.embeddings = Idefics2VisionEmbeddings(config) |
| 279 | + self.encoder = Idefics2Encoder(config) |
| 280 | + self.post_layernorm = nn.LayerNorm(embed_dim, |
| 281 | + eps=config.layer_norm_eps) |
| 282 | + |
| 283 | + def get_input_embeddings(self): |
| 284 | + return self.embeddings |
| 285 | + |
| 286 | + def forward( |
| 287 | + self, |
| 288 | + pixel_values, |
| 289 | + patch_attention_mask: Optional[torch.BoolTensor] = None, |
| 290 | + ) -> torch.tensor: |
| 291 | + hidden_states = self.embeddings( |
| 292 | + pixel_values=pixel_values, |
| 293 | + patch_attention_mask=patch_attention_mask) |
| 294 | + encoder_outputs = self.encoder(hidden_states) |
| 295 | + last_hidden_state = self.post_layernorm(encoder_outputs) |
| 296 | + return last_hidden_state |
0 commit comments