|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# An torch.export() friendly version of torchtune's positional embeddings. |
| 8 | +# Added torch._check() to make sure guards on symints are enforced. |
| 9 | +# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py |
| 10 | + |
| 11 | +import logging |
| 12 | +from typing import Any, Dict, Tuple |
| 13 | + |
| 14 | +import torch |
| 15 | +import torch.nn.functional as F |
| 16 | +from torch import nn |
| 17 | + |
| 18 | +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" |
| 19 | +logging.basicConfig(level=logging.INFO, format=FORMAT) |
| 20 | + |
| 21 | + |
| 22 | +class TilePositionalEmbedding(nn.Module): |
| 23 | + """ |
| 24 | + Positional embedding for tiles, different for every tile, same for every token within a tile. |
| 25 | +
|
| 26 | + Notice that tile is different from patch (token). For details, please check the documentation of |
| 27 | + :class:`torchtune.modules.vision_transformer.VisionTransformer`. |
| 28 | +
|
| 29 | + Args: |
| 30 | + max_num_tiles (int): The maximum number of tiles an image can be divided into. |
| 31 | + embed_dim (int): The dimensionality of each tile embedding. |
| 32 | + """ |
| 33 | + |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + max_num_tiles: int, |
| 37 | + embed_dim: int, |
| 38 | + ): |
| 39 | + super().__init__() |
| 40 | + self.max_num_tiles = max_num_tiles |
| 41 | + self.embed_dim = embed_dim |
| 42 | + |
| 43 | + scale = embed_dim**-0.5 |
| 44 | + self.embedding = nn.Parameter( |
| 45 | + scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) |
| 46 | + ) |
| 47 | + self.gate = nn.Parameter(torch.zeros(1)) |
| 48 | + |
| 49 | + # Register load hook to interpolate positional embeddings |
| 50 | + self._register_load_state_dict_pre_hook(self._load_state_dict_hook) |
| 51 | + |
| 52 | + # TODO: Switch to public method after 2.5 is stable |
| 53 | + @torch.no_grad() |
| 54 | + def _load_state_dict_hook( |
| 55 | + self, |
| 56 | + state_dict: Dict[str, Any], |
| 57 | + prefix: str, |
| 58 | + *args: Tuple[Any], |
| 59 | + **kwargs: Dict[str, Any], |
| 60 | + ): |
| 61 | + """ |
| 62 | + Interpolates positional embeddings to accomodate different number of tiles, |
| 63 | + in case the model was instantiated with different |
| 64 | + settings than the one you are loading the state dict from. |
| 65 | +
|
| 66 | + For more info, check self._dynamic_resize function. |
| 67 | +
|
| 68 | + Args: |
| 69 | + state_dict (Dict[str, Any]): The state dict to load. |
| 70 | + prefix (str): The prefix of the state dict. |
| 71 | + *args (Tuple[Any]): Additional positional arguments. |
| 72 | + **kwargs (Dict[str, Any]): Additional keyword arguments. |
| 73 | +
|
| 74 | + Raises: |
| 75 | + ValueError: if the shape of the loaded embedding is not compatible with the current embedding. |
| 76 | + ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. |
| 77 | + ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. |
| 78 | + """ |
| 79 | + |
| 80 | + embedding = state_dict.get(prefix + "embedding") |
| 81 | + |
| 82 | + if embedding is not None: |
| 83 | + |
| 84 | + # ckpt pos emb |
| 85 | + ( |
| 86 | + tgt_max_num_tiles_x, |
| 87 | + tgt_max_num_tiles_y, |
| 88 | + tgt_num_tokens, |
| 89 | + tgt_emb, |
| 90 | + ) = self.embedding.shape |
| 91 | + |
| 92 | + # instantiated pos emb |
| 93 | + ( |
| 94 | + inpt_max_num_tiles_x, |
| 95 | + inpt_max_num_tiles_y, |
| 96 | + inpt_num_tokens, |
| 97 | + inpt_emb, |
| 98 | + ) = state_dict[prefix + "embedding"].shape |
| 99 | + |
| 100 | + # sanity check |
| 101 | + if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb: |
| 102 | + raise ValueError( |
| 103 | + "Expected embedding shape to be (..., num_tokens, tgt_emb) to match" |
| 104 | + f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}" |
| 105 | + ) |
| 106 | + |
| 107 | + if inpt_max_num_tiles_x != inpt_max_num_tiles_y: |
| 108 | + raise ValueError( |
| 109 | + "Expected max_num_tiles_x, max_num_tiles_y to be equal but found, but found" |
| 110 | + f"(max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}" |
| 111 | + ) |
| 112 | + |
| 113 | + # resize ckpt to match instantiated shape |
| 114 | + embedding_new = self._resize_position_embedding( |
| 115 | + embedding, tgt_max_num_tiles=tgt_max_num_tiles_x |
| 116 | + ) |
| 117 | + |
| 118 | + # update state dict |
| 119 | + state_dict[prefix + "embedding"] = embedding_new |
| 120 | + if embedding_new.shape != self.embedding.shape: |
| 121 | + raise ValueError( |
| 122 | + "Expected embedding shape and embedding_new.shape to match" |
| 123 | + f" but found shapes {self.embedding.shape} and {embedding_new.shape}" |
| 124 | + ) |
| 125 | + |
| 126 | + @staticmethod |
| 127 | + def _resize_position_embedding( |
| 128 | + embedding: torch.Tensor, tgt_max_num_tiles: int |
| 129 | + ) -> torch.Tensor: |
| 130 | + """ |
| 131 | + Interpolates positional embeddings to accomodate a different max_num_tiles. These |
| 132 | + are the only dimensions that changes during interpolation. |
| 133 | +
|
| 134 | + Args: |
| 135 | + embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim |
| 136 | + tgt_max_num_tiles (int): The number of tiles to resize to. |
| 137 | +
|
| 138 | + Returns: |
| 139 | + torch.Tensor: The resized embedding. |
| 140 | +
|
| 141 | + Example: |
| 142 | + >>> import torch |
| 143 | + >>> # create dummy embedding |
| 144 | + >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float() |
| 145 | + >>> resized_embed = _dynamic_resize(embedding, tgt_max_num_tiles=1) |
| 146 | + >>> print(resized_embed.shape) |
| 147 | + >>> torch.Size([1, 1, 2, 2]) |
| 148 | + """ |
| 149 | + # set max_num_tiles to the last dimension |
| 150 | + embedding = embedding.permute(2, 3, 0, 1) |
| 151 | + |
| 152 | + embedding = F.interpolate( |
| 153 | + embedding, |
| 154 | + size=(tgt_max_num_tiles, tgt_max_num_tiles), |
| 155 | + mode="bilinear", |
| 156 | + align_corners=True, |
| 157 | + ) |
| 158 | + # permute to the original shape |
| 159 | + embedding = embedding.permute(2, 3, 0, 1) |
| 160 | + return embedding |
| 161 | + |
| 162 | + def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: |
| 163 | + """ |
| 164 | + args: |
| 165 | + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). |
| 166 | + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), |
| 167 | + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). |
| 168 | + returns: |
| 169 | + torch.Tensor: The input tensor with added positional embeddings. |
| 170 | + """ |
| 171 | + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape |
| 172 | + torch._check(n_tiles <= self.max_num_tiles) |
| 173 | + |
| 174 | + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): |
| 175 | + # When we batch images, all are padded to the same amount of tiles. |
| 176 | + # The aspect_ratio lets us know the non padded tiles for each image. |
| 177 | + # We only add positional encoding to those. |
| 178 | + n_tiles_h = n_tiles_h.item() |
| 179 | + n_tiles_w = n_tiles_w.item() |
| 180 | + |
| 181 | + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) |
| 182 | + |
| 183 | + # We get only the positional encoding for non padded tiles, |
| 184 | + # i.e. n_tiles_h, n_tiles_w. |
| 185 | + torch._check_is_size(n_tiles_h) |
| 186 | + torch._check_is_size(n_tiles_w) |
| 187 | + torch._check(n_tiles_h >= 1) |
| 188 | + torch._check(n_tiles_w >= 1) |
| 189 | + torch._check(n_tiles_h <= self.max_num_tiles) |
| 190 | + torch._check(n_tiles_w <= self.max_num_tiles) |
| 191 | + # TODO: Remove this once pytorch/pytorch#120288 is fixed |
| 192 | + padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) |
| 193 | + pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] |
| 194 | + |
| 195 | + # We need to do a clone here in order to make this model export |
| 196 | + # friendly as the reshape is collapsing dim 0 and dim 1 into a |
| 197 | + # single dim. |
| 198 | + pos_embed = pos_embed.clone() |
| 199 | + pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim) |
| 200 | + |
| 201 | + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) |
| 202 | + torch._check_is_size(n_non_padded_tiles) |
| 203 | + torch._check(n_non_padded_tiles < x.size(1)) |
| 204 | + x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() |
| 205 | + x = x[:, :n_tiles, :, :] |
| 206 | + |
| 207 | + return x |
| 208 | + |
| 209 | + |
| 210 | +def replace_tile_positional_embedding(model: nn.Module) -> nn.Module: |
| 211 | + """ |
| 212 | + Replace the tile positional embedding from torchtune with an export-friendly one. |
| 213 | + Recursively searches the submodules of the model and replaces the tile positional embedding if found. |
| 214 | + Args: |
| 215 | + model (nn.Module): The model to replace the tile positional embedding in. |
| 216 | +
|
| 217 | + Returns: |
| 218 | + nn.Module: The model after replacing the tile positional embedding. |
| 219 | +
|
| 220 | + """ |
| 221 | + from torchtune.models.clip._position_embeddings import ( |
| 222 | + TilePositionalEmbedding as TuneTilePositionalEmbedding, |
| 223 | + ) |
| 224 | + |
| 225 | + for name, module in model.named_children(): |
| 226 | + if isinstance(module, TuneTilePositionalEmbedding): |
| 227 | + logging.info( |
| 228 | + f"Replacing tile positional embedding in {name} with export-friendly one." |
| 229 | + ) |
| 230 | + max_num_tiles, _, _, embed_dim = module.embedding.shape |
| 231 | + mod = TilePositionalEmbedding( |
| 232 | + max_num_tiles=max_num_tiles, |
| 233 | + embed_dim=embed_dim, |
| 234 | + ) |
| 235 | + mod.load_state_dict(module.state_dict()) |
| 236 | + setattr( |
| 237 | + model, |
| 238 | + name, |
| 239 | + mod, |
| 240 | + ) |
| 241 | + else: |
| 242 | + replace_tile_positional_embedding(module) |
| 243 | + return model |
0 commit comments