
[Fix] ViViT interpolate_pos_encoding #33815


Merged
7 changes: 4 additions & 3 deletions src/transformers/models/vivit/modeling_vivit.py
@@ -104,9 +104,10 @@ def __init__(self, config):
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.tubelet_size[1:]
self.config = config

- # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
+ # Adapted from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -129,8 +130,8 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

dim = embeddings.shape[-1]

- new_height = height // self.patch_size
- new_width = width // self.patch_size
+ new_height = height // self.patch_size[0]
+ new_width = width // self.patch_size[1]

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
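For reference, here is a minimal standalone sketch (not the PR's code) of the spatial position-embedding interpolation this change drives; the tubelet shape (2, 16, 16), the 224x224 pretraining resolution, and the hidden size of 768 are assumptions meant to mirror google/vivit-b-16x2-kinetics400.

    import torch
    import torch.nn as nn

    # Assumed ViViT geometry (stand-ins for the real config values):
    tubelet_size = (2, 16, 16)       # (time, height, width)
    patch_size = tubelet_size[1:]    # spatial part only -> (16, 16), as stored by the fix
    dim = 768

    # Pretrained spatial grid at 224x224: 14 x 14 patches per frame.
    grid = 224 // patch_size[0]
    patch_pos_embed = torch.randn(1, grid * grid, dim)

    def spatially_interpolate(pos_embed, height, width):
        # Resize the pretrained spatial grid to the grid implied by the new crop,
        # indexing the height and width patch sizes separately as the fix does.
        new_height = height // patch_size[0]
        new_width = width // patch_size[1]
        old_grid = int(pos_embed.shape[1] ** 0.5)
        x = pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
        x = nn.functional.interpolate(x, size=(new_height, new_width), mode="bicubic", align_corners=False)
        return x.permute(0, 2, 3, 1).reshape(1, -1, dim)

    print(spatially_interpolate(patch_pos_embed, 480, 480).shape)  # torch.Size([1, 900, 768])

Because patch_size now keeps both spatial entries, the height and width of the crop are each divided by the matching tubelet dimension, so non-square spatial patch sizes would also be handled.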
6 changes: 3 additions & 3 deletions tests/models/vivit/test_modeling_vivit.py
@@ -359,12 +359,12 @@ def test_inference_interpolate_pos_encoding(self):
# allowing to interpolate the pre-trained position embeddings in order to use
# the model on higher resolutions. The DINO model by Facebook AI leverages this
# to visualize self-attention on higher resolution images.
- model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)
+ model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400").to(torch_device)

- image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
+ image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
Collaborator
Why change the checkpoint and the crop_size in the test?

Contributor Author (@RUFFY-369), Sep 30, 2024
@amyeroberts

  • Changing the checkpoint was a test to see whether it fixes the test failure in the PR slow CI run that you triggered (the one I mentioned above).
  • Certain crop_size values lead to an error when the interpolation method is called; for example, with crop_size={"height": 480, "width": 480} the following error occurs:
# add positional encoding to each token
        if interpolate_pos_encoding:
>           embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
E           RuntimeError: The size of tensor a (14401) must match the size of tensor b (901) at non-singleton dimension 1

The same happens with some other crop sizes as well, but the error doesn't occur for crop sizes like 232 or 228, or for the default crop_size of 224.
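For context, the two sizes in that error line up with the ViViT token arithmetic; a quick check, assuming the checkpoint's usual 32 frames and a (2, 16, 16) tubelet:

    # Assumed geometry for google/vivit-b-16x2-kinetics400: 32 frames, 2x16x16 tubelets.
    num_frames = 32
    tubelet = (2, 16, 16)  # (time, height, width)

    # Tokens built from a 480x480 crop: temporal tubelets * spatial patches + [CLS].
    temporal = num_frames // tubelet[0]                  # 16
    spatial = (480 // tubelet[1]) * (480 // tubelet[2])  # 30 * 30 = 900
    print(temporal * spatial + 1)                        # 14401 -> "tensor a" in the error

    # The ViT-style interpolation returns only a spatial grid of positions plus [CLS].
    print(spatial + 1)                                   # 901 -> "tensor b" in the error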

Collaborator
OK, thanks for explaining. The error shouldn't be triggered for the default crop_size value (no interpolation should happen), but if it works for these non-default values then it's all good :)

Contributor Author
The default crop_size value is 224 in all the image processing files, right? The error doesn't occur for that value, though. The value in the test file was chosen both so that the error doesn't occur and so that something other than the default is tested.

video = prepare_video()
inputs = image_processor(
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
Collaborator
The crop_size option should still be included in the test, as this will force the interpolation.

Contributor Author
sure, let me push the commit in a second

Contributor Author
done 👍

video, size={"shortest_edge": 480}, crop_size={"height": 232, "width": 232}, return_tensors="pt"
)
pixel_values = inputs.pixel_values.to(torch_device)
