Commit f7e26d7

[llama-mm] Add export-friendly tile position embedding (#6671)
Summary: Until a decision is made on whether torchtune adopts this export-friendly version of `TilePositionEmbedding`, we put it under `extension/llm` so that users can start using it. Added unit tests to make sure the behavior matches the reference implementation in torchtune and that export, AOTI, and ExecuTorch (ET) all work properly.

ghstack-source-id: fe65ec6
Pull Request resolved: #6650
Co-authored-by: Mengwei Liu <larryliu@meta.com>
1 parent 735e019 commit f7e26d7
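
For context, a minimal usage sketch of the new helper; the `VisionStub` wrapper below is illustrative only and mirrors the unit test added in this commit:

# Hypothetical stand-in for a model that embeds torchtune's tile positional
# embedding somewhere in its submodule tree.
import torch

from executorch.extension.llm.modules import replace_tile_positional_embedding
from torchtune.models.clip import TilePositionalEmbedding as TuneTilePositionalEmbedding


class VisionStub(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tpe = TuneTilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)

    def forward(self, x, aspect_ratio):
        return self.tpe(x, aspect_ratio)


model = replace_tile_positional_embedding(VisionStub())
# model.tpe is now the export-friendly TilePositionalEmbedding, carrying over
# the original weights via load_state_dict inside the replacement helper.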

File tree

6 files changed: +391 -0 lines changed

extension/llm/modules/README.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
## Export Friendly Modules

Modules in this directory:

* Extend `torch.nn.Module`.
* Are guaranteed to work out of the box with `torch.export.export()` and `torch._export.aot_compile()`.
* Are guaranteed to work with ExecuTorch.

All modules should be covered by unit tests to make sure they are:

1. giving the same output as the reference implementation in PyTorch or torchtune
2. export friendly
3. AOTI friendly
4. ExecuTorch friendly

Note that these modules are subject to change (they may be upstreamed to torchtune), so proceed with caution.
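
The export guarantee described in the README can be exercised with a short sketch like the following; the shapes and dynamic-dimension names are borrowed from the unit tests in this commit and are otherwise illustrative:

import torch

from executorch.extension.llm.modules import TilePositionalEmbedding

tpe = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
x = torch.randn(1, 4, 1600, 1280)      # (bsz * n_imgs, n_tiles, n_tokens, embed_dim)
aspect_ratio = torch.tensor([[1, 1]])  # (bsz * n_imgs, 2)

num_tiles = torch.export.Dim("num_tiles", min=1, max=4)
num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)
ep = torch.export.export(
    tpe,
    (x, aspect_ratio),
    # Only the tile and token dims are dynamic; aspect_ratio stays static.
    dynamic_shapes=({1: num_tiles, 2: num_tokens}, None),
)
print(ep.module()(x, aspect_ratio).shape)  # torch.Size([1, 4, 1600, 1280])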

extension/llm/modules/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._position_embeddings import (
    replace_tile_positional_embedding,
    TilePositionalEmbedding,
)

__all__ = [
    "TilePositionalEmbedding",
    "replace_tile_positional_embedding",
]

extension/llm/modules/_position_embeddings.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# A torch.export()-friendly version of torchtune's positional embeddings.
# Added torch._check() to make sure guards on symints are enforced.
# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py

import logging
from typing import Any, Dict, Tuple

import torch
import torch.nn.functional as F
from torch import nn

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class TilePositionalEmbedding(nn.Module):
    """
    Positional embedding for tiles, different for every tile, same for every token within a tile.

    Notice that tile is different from patch (token). For details, please check the documentation of
    :class:`torchtune.modules.vision_transformer.VisionTransformer`.

    Args:
        max_num_tiles (int): The maximum number of tiles an image can be divided into.
        embed_dim (int): The dimensionality of each tile embedding.
    """

    def __init__(
        self,
        max_num_tiles: int,
        embed_dim: int,
    ):
        super().__init__()
        self.max_num_tiles = max_num_tiles
        self.embed_dim = embed_dim

        scale = embed_dim**-0.5
        self.embedding = nn.Parameter(
            scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
        )
        self.gate = nn.Parameter(torch.zeros(1))

        # Register load hook to interpolate positional embeddings
        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)

    # TODO: Switch to public method after 2.5 is stable
    @torch.no_grad()
    def _load_state_dict_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        *args: Tuple[Any],
        **kwargs: Dict[str, Any],
    ):
        """
        Interpolates positional embeddings to accommodate a different number of tiles,
        in case the model was instantiated with different
        settings than the one you are loading the state dict from.

        For more info, check the self._resize_position_embedding function.

        Args:
            state_dict (Dict[str, Any]): The state dict to load.
            prefix (str): The prefix of the state dict.
            *args (Tuple[Any]): Additional positional arguments.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Raises:
            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
            ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
            ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
        """

        embedding = state_dict.get(prefix + "embedding")

        if embedding is not None:

            # shape of the pos emb in the instantiated module (the target shape)
            (
                tgt_max_num_tiles_x,
                tgt_max_num_tiles_y,
                tgt_num_tokens,
                tgt_emb,
            ) = self.embedding.shape

            # shape of the pos emb in the checkpoint being loaded
            (
                inpt_max_num_tiles_x,
                inpt_max_num_tiles_y,
                inpt_num_tokens,
                inpt_emb,
            ) = state_dict[prefix + "embedding"].shape

            # sanity check
            if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
                raise ValueError(
                    "Expected embedding shapes to match in the last two dims (num_tokens, embed_dim),"
                    f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
                )

            if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
                raise ValueError(
                    "Expected max_num_tiles_x and max_num_tiles_y to be equal, but found"
                    f" (max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {self.embedding.shape}"
                )

            # resize ckpt to match instantiated shape
            embedding_new = self._resize_position_embedding(
                embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
            )

            # update state dict
            state_dict[prefix + "embedding"] = embedding_new
            if embedding_new.shape != self.embedding.shape:
                raise ValueError(
                    "Expected embedding shape and embedding_new.shape to match"
                    f" but found shapes {self.embedding.shape} and {embedding_new.shape}"
                )

    @staticmethod
    def _resize_position_embedding(
        embedding: torch.Tensor, tgt_max_num_tiles: int
    ) -> torch.Tensor:
        """
        Interpolates positional embeddings to accommodate a different max_num_tiles. These
        are the only dimensions that change during interpolation.

        Args:
            embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim).
            tgt_max_num_tiles (int): The number of tiles to resize to.

        Returns:
            torch.Tensor: The resized embedding.

        Example:
            >>> import torch
            >>> # create dummy embedding
            >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float()
            >>> resized_embed = _resize_position_embedding(embedding, tgt_max_num_tiles=1)
            >>> print(resized_embed.shape)
            torch.Size([1, 1, 2, 2])
        """
        # set max_num_tiles to the last dimension
        embedding = embedding.permute(2, 3, 0, 1)

        embedding = F.interpolate(
            embedding,
            size=(tgt_max_num_tiles, tgt_max_num_tiles),
            mode="bilinear",
            align_corners=True,
        )
        # permute to the original shape
        embedding = embedding.permute(2, 3, 0, 1)
        return embedding

    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
                representing the aspect ratio of the image before tile-cropping, e.g. (2, 1).

        Returns:
            torch.Tensor: The input tensor with added positional embeddings.
        """
        bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape
        torch._check(n_tiles <= self.max_num_tiles)

        for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
            # When we batch images, all are padded to the same amount of tiles.
            # The aspect_ratio lets us know the non padded tiles for each image.
            # We only add positional encoding to those.
            n_tiles_h = n_tiles_h.item()
            n_tiles_w = n_tiles_w.item()

            n_non_padded_tiles = int(n_tiles_h * n_tiles_w)

            # We get only the positional encoding for non padded tiles,
            # i.e. n_tiles_h, n_tiles_w.
            torch._check_is_size(n_tiles_h)
            torch._check_is_size(n_tiles_w)
            torch._check(n_tiles_h >= 1)
            torch._check(n_tiles_w >= 1)
            torch._check(n_tiles_h <= self.max_num_tiles)
            torch._check(n_tiles_w <= self.max_num_tiles)
            # TODO: Remove this once pytorch/pytorch#120288 is fixed
            padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
            pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

            # We need to do a clone here in order to make this model export
            # friendly as the reshape is collapsing dim 0 and dim 1 into a
            # single dim.
            pos_embed = pos_embed.clone()
            pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)

            x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
            torch._check_is_size(n_non_padded_tiles)
            torch._check(n_non_padded_tiles < x.size(1))
            x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh()
            x = x[:, :n_tiles, :, :]

        return x


def replace_tile_positional_embedding(model: nn.Module) -> nn.Module:
    """
    Replace the tile positional embedding from torchtune with an export-friendly one.
    Recursively searches the submodules of the model and replaces the tile positional embedding if found.

    Args:
        model (nn.Module): The model to replace the tile positional embedding in.

    Returns:
        nn.Module: The model after replacing the tile positional embedding.
    """
    from torchtune.models.clip._position_embeddings import (
        TilePositionalEmbedding as TuneTilePositionalEmbedding,
    )

    for name, module in model.named_children():
        if isinstance(module, TuneTilePositionalEmbedding):
            logging.info(
                f"Replacing tile positional embedding in {name} with export-friendly one."
            )
            max_num_tiles, _, _, embed_dim = module.embedding.shape
            mod = TilePositionalEmbedding(
                max_num_tiles=max_num_tiles,
                embed_dim=embed_dim,
            )
            mod.load_state_dict(module.state_dict())
            setattr(
                model,
                name,
                mod,
            )
        else:
            replace_tile_positional_embedding(module)
    return model
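
A minimal sketch of the load-time interpolation performed by `_load_state_dict_hook` above: a checkpoint saved from a module instantiated with `max_num_tiles=2` is resized bilinearly (via `_resize_position_embedding`) when loaded into a module instantiated with `max_num_tiles=4`. The sizes below are illustrative:

import torch

from executorch.extension.llm.modules import TilePositionalEmbedding

small = TilePositionalEmbedding(max_num_tiles=2, embed_dim=64)
large = TilePositionalEmbedding(max_num_tiles=4, embed_dim=64)

# The pre-hook interpolates the (2, 2, 1, 64) checkpoint embedding up to the
# (4, 4, 1, 64) shape expected by `large` before load_state_dict applies it.
large.load_state_dict(small.state_dict())
print(large.embedding.shape)  # torch.Size([4, 4, 1, 64])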

extension/llm/modules/test/__init__.py

Whitespace-only changes.

extension/llm/modules/test/test_position_embeddings.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.extension.llm.modules import (
    replace_tile_positional_embedding,
    TilePositionalEmbedding,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torchtune.models.clip import TilePositionalEmbedding as TuneTilePositionalEmbedding


class TilePositionalEmbeddingTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        self.tpe = TilePositionalEmbedding(4, 1280)
        self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
        self.x = torch.randn(1, 4, 1600, 1280)
        self.aspect_ratio = torch.tensor([[1, 1]])
        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
        num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)

        self.dynamic_shape = {
            0: 1,  # batch
            1: num_tiles_dim,  # num tiles
            2: num_tokens,  # num tokens
            3: 1280,  # embedding dim
        }

    def test_tile_positional_embedding_smoke(self):
        y = self.tpe(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_export(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )

        y = tpe_ep.module()(self.x, self.aspect_ratio)
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_aoti(self):
        so = torch._export.aot_compile(
            self.tpe,
            args=(self.x, self.aspect_ratio),
            options={"aot_inductor.package": True},
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
            tpe_aoti = load_package(path)

            y = tpe_aoti(self.x, self.aspect_ratio)
            ref_y = self.ref_tpe(self.x, self.aspect_ratio)

            self.assertTrue(torch.allclose(y, ref_y))

    def test_tile_positional_embedding_et(self):
        tpe_ep = torch.export.export(
            self.tpe,
            (self.x, self.aspect_ratio),
            dynamic_shapes=(
                self.dynamic_shape,
                None,
            ),  # assuming aspect ratio is static
        )
        et_program = to_edge(
            tpe_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[
                    torch.ops.aten.sym_constrain_range_for_size.default,
                    torch.ops.aten._assert_scalar.default,
                    torch.ops.aten._local_scalar_dense.default,
                ]
            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        y = method.execute((self.x, self.aspect_ratio))
        ref_y = self.ref_tpe(self.x, self.aspect_ratio)

        self.assertTrue(torch.allclose(y[0], ref_y))

    def test_replace_tile_positional_embedding(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tpe = TuneTilePositionalEmbedding(4, 1280)

            def forward(self, x, aspect_ratio):
                return self.tpe(x, aspect_ratio)

        m = Module()
        m = replace_tile_positional_embedding(m)
        self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))
