[FIX] MM Eval Mask Sizes (#1920)
pbontrager authored Oct 30, 2024
1 parent 4fb2464 commit a1bcb97
Showing 7 changed files with 39 additions and 29 deletions.
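In short: multimodal eval could produce mismatched mask sizes because the collator padded cross-attention masks only to the largest image present in the batch, while the model expects masks sized for its configured maximum tile count. The diff below threads pad_max_tiles through padded_collate_tiled_images_and_mask so masks always reach pad_max_images * max_num_tiles * tokens_per_tile columns. A minimal sketch of the size arithmetic (variable names are illustrative, not torchtune's API):

# Minimal sketch of the size arithmetic behind this fix (names are
# illustrative). Two images with 2 and 3 tiles, 5 tokens per tile,
# and a model configured for up to 4 tiles:
tokens_per_tile = 5
batch_tile_counts = [2, 3]
max_num_tiles = 4
pad_max_images = 1

# Before: masks were padded to the largest image present in the batch.
old_width = pad_max_images * max(batch_tile_counts) * tokens_per_tile  # 15
# After: masks are padded to the model's configured maximum.
new_width = pad_max_images * max_num_tiles * tokens_per_tile  # 20
assert (old_width, new_width) == (15, 20)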
5 changes: 4 additions & 1 deletion recipes/dev/generate_v2.py
@@ -152,7 +152,10 @@ def generate(self, cfg: DictConfig):
         batch = {}
         if is_multimodal_input:
             batch = padded_collate_tiled_images_and_mask(
-                [model_inputs], pad_direction="left", pad_max_images=1
+                [model_inputs],
+                pad_direction="left",
+                pad_max_images=1,
+                pad_max_tiles=self.model_transform.max_num_tiles,
             )
             batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
             prompt = batch.pop("tokens").to(self._device)
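Both this recipe and the Eleuther eval recipe below now forward the transform's max_num_tiles to the collator. A hedged sketch of the call, mirroring the test fixture later in this diff (the sample dict and keyword values are illustrative):

import torch
from torchtune.data import padded_collate_tiled_images_and_mask

tokens_per_tile, max_num_tiles = 5, 4
sample = {
    "tokens": [1, 2, 1, 3],
    "labels": [4, 5, 6, 7],
    "encoder_input": {
        "images": [torch.ones(2, 1, 1, 1)],  # one image with 2 tiles
        "aspect_ratio": [torch.tensor([1, 2])],
    },
    "encoder_mask": [torch.ones(4, tokens_per_tile * 2)],
}
batch = padded_collate_tiled_images_and_mask(
    [sample],
    padding_idx=0,
    ignore_idx=-100,
    pad_direction="left",
    pad_max_images=1,
    pad_max_tiles=max_num_tiles,
)
# The mask width is now pad_max_images * max_num_tiles * tokens_per_tile.
assert batch["encoder_mask"].shape[-1] == 1 * max_num_tiles * tokens_per_tile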
1 change: 1 addition & 0 deletions recipes/eleuther_eval.py
@@ -187,6 +187,7 @@ def tok_batch_multimodal_encode(
             all_encoded_messages,
             pad_direction="left",
             pad_max_images=self._max_images_per_sample,
+            pad_max_tiles=self._transform.max_num_tiles,
         )
         utils.batch_to_device(tok_batch, self.device)

40 changes: 28 additions & 12 deletions tests/torchtune/data/test_collate.py
@@ -56,33 +56,41 @@ def test_batch_pad_sequence(self):


 class TestPaddedCollateTiledImagesAndMask:
+    img_shape = 1, 1, 1
+    tokens_per_tile = 5
+
     @pytest.fixture
     def batch(self):
+        c, h, w = self.img_shape
+        s = self.tokens_per_tile
         return [
             {
                 "tokens": [1, 2, 1, 3],
                 "labels": [4, 5, 6, 7],
                 "encoder_input": {
-                    "images": [torch.ones(2, 1, 1, 1), torch.ones(3, 1, 1, 1)],
+                    "images": [torch.ones(2, c, h, w), torch.ones(3, c, h, w)],
                     "aspect_ratio": [torch.tensor([1, 2]), torch.tensor([1, 3])],
                 },
-                "encoder_mask": [torch.ones(4, 5 * 2), torch.ones(4, 5 * 3)],
+                "encoder_mask": [torch.ones(4, s * 2), torch.ones(4, s * 3)],
             },
             {
                 "tokens": [1, 4],
                 "labels": [8, 9],
                 "encoder_input": {
-                    "images": [torch.ones(4, 1, 1, 1)],
+                    "images": [torch.ones(4, c, h, w)],
                     "aspect_ratio": [torch.tensor([2, 2])],
                 },
-                "encoder_mask": [torch.ones(2, 5 * 4)],
+                "encoder_mask": [torch.ones(2, s * 4)],
             },
         ]

     def test_right_pad_sequence(self, batch):
         actual = padded_collate_tiled_images_and_mask(
             batch=batch, padding_idx=0, ignore_idx=-100, pad_direction="right"
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert imgs * tiles * self.tokens_per_tile == seq_len

         mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
         mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
@@ -126,28 +134,36 @@ def test_left_pad_sequence(self, batch):
             ignore_idx=-100,
             pad_direction="left",
             pad_max_images=4,
+            pad_max_tiles=5,
         )
+        imgs, tiles = actual["encoder_input"]["images"].shape[1:3]
+        seq_len = actual["encoder_mask"].shape[-1]
+        assert 5 * 4 * self.tokens_per_tile == seq_len

-        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 10)], dim=1)
-        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5)], dim=1)
+        # pad 3 extra tiles
+        mask_1 = torch.concat([torch.ones(4, 5 * 2), torch.zeros(4, 5 * 3)], dim=1)
+        # pad 2 extra tiles
+        mask_2 = torch.concat([torch.ones(4, 5 * 3), torch.zeros(4, 5 * 2)], dim=1)
+        # Left pad text tokens
         mask_3 = torch.concat([torch.zeros(2, 20), torch.ones(2, 5 * 4)], dim=0)
+        mask_3 = F.pad(mask_3, (0, 5), value=0)  # pad 5th tile
         sample_1 = torch.stack([mask_1, mask_2])
-        sample_2 = torch.stack([mask_3, torch.zeros(4, 20)])
+        sample_2 = torch.stack([mask_3, torch.zeros(4, 25)])
         expected_mask = torch.stack([sample_1, sample_2]).view(2, 4, -1)
-        expected_mask = F.pad(expected_mask, (0, 40), value=0)
+        expected_mask = F.pad(expected_mask, (0, 50), value=0)

         expected = {
             "tokens": torch.tensor([[1, 2, 1, 3], [0, 0, 1, 4]]),
             "encoder_input": {
                 "images": torch.tensor(
                     [
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                         [
-                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]]],
-                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
+                            [[[[1.0]]], [[[1.0]]], [[[1.0]]], [[[1.0]]], [[[0.0]]]],
+                            [[[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]], [[[0.0]]]],
                         ],
                     ]
                 ),
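The updated expectations follow directly from the new padding rule; working the numbers from the fixture by hand:

# Hand-check of the expected shapes above (plain arithmetic, no torchtune):
tokens_per_tile = 5
pad_max_images, pad_max_tiles = 4, 5

# Each image slot is padded to pad_max_tiles * tokens_per_tile columns...
slot_width = pad_max_tiles * tokens_per_tile  # 25, hence torch.zeros(4, 25)

# ...and the flattened mask reaches pad_max_images slots in total:
seq_len = pad_max_images * slot_width  # 100, matching the assert above

# Sample 1 holds 2 real images, so 2 * 25 = 50 all-zero columns remain,
# matching F.pad(expected_mask, (0, 50)).
assert (slot_width, seq_len) == (25, 100)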
4 changes: 1 addition & 3 deletions tests/torchtune/modules/transforms/test_transforms.py
@@ -10,7 +10,6 @@


 IMAGE_TOKEN_ID = 1
-MAX_NUM_TILES = 4


 class TestVisionCrossAttentionMask:
@@ -54,7 +53,6 @@ def cross_attn_mask_transform(self, tile_size, patch_size):
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=IMAGE_TOKEN_ID,
-            max_num_tiles=MAX_NUM_TILES,
         )

     def test_get_image_attention_intervals(self, cross_attn_mask_transform, tokens):
@@ -89,7 +87,7 @@ def test_inference_call(
         sample.update(dummy_kwargs)
         actual = cross_attn_mask_transform(sample, inference=True)
         expected = [
-            torch.zeros(len(tokens), image_num_tokens * 2, dtype=torch.bool)
+            torch.zeros(len(tokens), image_num_tokens, dtype=torch.bool)
             for _ in range(len(images))
         ]
         expected[0][2:6, :image_num_tokens] = True
3 changes: 2 additions & 1 deletion torchtune/data/_collate.py
@@ -426,7 +426,8 @@ def padded_collate_tiled_images_and_mask(
     if pad_max_images is not None:
         _, _, img_seq = concat_masks.shape
         concat_masks = F.pad(
-            concat_masks, (0, pad_max_images * image_seq_len - img_seq)
+            concat_masks,
+            (0, pad_max_images * max_num_tiles * tokens_per_tile - img_seq),
         )

     batch_dict = {
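The collator change in isolation: the pad target no longer depends on image_seq_len, which is derived from the largest image actually in the batch, but on the configured tile maximum. A standalone reproduction of just this padding step (the tensor and sizes are stand-ins for the collator's locals):

import torch
import torch.nn.functional as F

tokens_per_tile, max_num_tiles, pad_max_images = 5, 4, 2
# (bsz, text_seq_len, img_seq): masks for one sample whose largest image
# spans 3 tiles, i.e. 15 image tokens.
concat_masks = torch.ones(1, 7, 15)

_, _, img_seq = concat_masks.shape
concat_masks = F.pad(
    concat_masks,
    (0, pad_max_images * max_num_tiles * tokens_per_tile - img_seq),
)
assert concat_masks.shape == (1, 7, 40)  # 2 images * 4 tiles * 5 tokens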
2 changes: 1 addition & 1 deletion torchtune/models/llama3_2_vision/_transform.py
@@ -93,11 +93,11 @@ def __init__(
             tile_size=tile_size,
             patch_size=patch_size,
             image_token_id=self.tokenizer.image_id,
-            max_num_tiles=max_num_tiles,
         )

         self.stop_tokens = self.tokenizer.stop_tokens
         self.max_seq_len = max_seq_len
+        self.max_num_tiles = max_num_tiles
         self.image_seq_len = max_num_tiles * (self.xattn_mask.patches_per_tile + 1)
         self.prompt_template = prompt_template
         self.pad_id = self.tokenizer.pad_id
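The transform now simply exposes max_num_tiles as an attribute so the recipes above can hand it to the collator; the cross-attention mask transform itself no longer needs it. For orientation, image_seq_len follows from the vision geometry; a worked instance assuming common Llama 3.2 Vision settings (tile size 448, patch size 14 are assumed example values):

tile_size, patch_size, max_num_tiles = 448, 14, 4  # assumed example values
patches_per_tile = (tile_size // patch_size) ** 2  # 32 * 32 = 1024
# +1 for the extra per-tile embedding added by the vision encoder.
image_seq_len = max_num_tiles * (patches_per_tile + 1)  # 4100
assert image_seq_len == 4100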
13 changes: 2 additions & 11 deletions torchtune/modules/transforms/_transforms.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Any, List, Mapping, Optional, Protocol
+from typing import Any, List, Mapping, Protocol

 import torch

@@ -57,21 +57,17 @@ class VisionCrossAttentionMask(Transform):
             E.g. for patch_size = 40, a tile of shape (400, 400) will have 10x10 grid of patches
             with shape (40, 40) each.
         image_token_id (int): Token ID of the image special token.
-        max_num_tiles (Optional[int]): Maximum number of tiles in an image, used to
-            pad mask during inference. Defaults to None
     """

     def __init__(
         self,
         tile_size: int,
         patch_size: int,
         image_token_id: int,
-        max_num_tiles: Optional[int] = None,
     ):
         patch_grid_size = tile_size // patch_size
         self.patches_per_tile = patch_grid_size**2
         self.image_token_id = image_token_id
-        self.max_num_tiles = max_num_tiles

     def _get_image_attention_intervals(self, tokens: List[int]) -> List[List[int]]:
         """
@@ -163,9 +159,6 @@ def __call__(
         # which can vary based on number of tiles since they are not yet tile padded.
         # The masks are padded and concatenated together in the batch collator
         text_seq_len = len(tokens)
-        max_image_size = None
-        if inference and self.max_num_tiles is not None:
-            max_image_size = self.max_num_tiles * (self.patches_per_tile + 1)
         masks = []
         for image_num, interval in enumerate(intervals):
             # Identify what part of text sequence should be attended
@@ -178,9 +171,7 @@ def __call__(
             # to a single image, so text tokens attend to all the image's tokens.
             # The mask is text_seq_len x mask_image_size if defined, otherwise
             # it uses current text/image sequence lengths.
-            mask = torch.zeros(
-                text_seq_len, max_image_size or image_seq_len, dtype=torch.bool
-            )
+            mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
             mask[start:end, :image_seq_len] = True
             masks.append(mask)

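After this change, VisionCrossAttentionMask always builds each image's mask at its natural size, for training and inference alike, and all padding is deferred to padded_collate_tiled_images_and_mask. A condensed sketch of the remaining per-image logic (the helper name and argument shapes are mine, not the library's):

import torch
from typing import List, Tuple

def build_masks(
    text_seq_len: int,
    intervals: List[Tuple[int, int]],
    image_token_counts: List[int],
) -> List[torch.Tensor]:
    # One boolean mask per image: the text span [start, end) attends to
    # every token of its image; nothing is pre-padded to a tile maximum.
    masks = []
    for (start, end), image_seq_len in zip(intervals, image_token_counts):
        mask = torch.zeros(text_seq_len, image_seq_len, dtype=torch.bool)
        mask[start:end, :] = True
        masks.append(mask)
    return masks

# Tokens 2..5 of an 8-token sequence attend to a 2-tile, 10-token image:
print(build_masks(8, [(2, 6)], [10])[0].shape)  # torch.Size([8, 10])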
