
[Fix] ViViT interpolate_pos_encoding #33815

Merged

Conversation

RUFFY-369 (Contributor):

What does this PR do?

Fixes #33814
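
For context, here is a minimal sketch of the kind of 2D position-embedding interpolation a ViT-style interpolate_pos_encoding path performs (illustrative only, not the exact transformers implementation; the function name and shapes are placeholders):

import torch
import torch.nn.functional as F

def interpolate_pos_encoding_sketch(pos_embed, new_height, new_width, patch_size=(16, 16)):
    # pos_embed: (1, 1 + old_grid * old_grid, dim), with the [CLS] position first
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    old_grid = int(patch_pos.shape[1] ** 0.5)
    new_grid = (new_height // patch_size[0], new_width // patch_size[1])
    # reshape the flat patch positions back onto their 2D grid and resize bicubically
    patch_pos = patch_pos.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos = F.interpolate(patch_pos, size=new_grid, mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, -1, dim)
    return torch.cat([cls_pos, patch_pos], dim=1)

The fix in this PR is about making sure the spatial patch size used for that grid arithmetic is correct for ViViT's tubelet embeddings (see the review discussion below).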

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts

RUFFY-369 mentioned this pull request on Sep 30, 2024
@amyeroberts (Collaborator) left a comment:

Thanks for fixing!

Main comment is about the test

@@ -363,9 +363,7 @@ def test_inference_interpolate_pos_encoding(self):

image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
video = prepare_video()
inputs = image_processor(
video, size={"shortest_edge": 480}, crop_size={"height": 480, "width": 480}, return_tensors="pt"
Collaborator:

The crop_size option should still be included in the test, as this will force the interpolation.

Contributor (PR author):

sure, let me push the commit in a second

Contributor (PR author):

done 👍

@@ -104,9 +104,10 @@ def __init__(self, config):
torch.zeros(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.patch_size = config.tubelet_size
Collaborator:

It would be better to set this correctly to a patch_size, i.e. a tuple of length 2, rather than assign it the tubelet size.

Suggested change
- self.patch_size = config.tubelet_size
+ self.patch_size = config.tubelet_size[1:]
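
To illustrate the distinction (the values below assume the vivit-b-16x2 configuration, where a tubelet spans 2 frames and a 16x16 spatial patch; this snippet is illustrative, not code from the PR):

tubelet_size = (2, 16, 16)      # (frames, height, width), as in vivit-b-16x2
patch_size = tubelet_size[1:]   # (16, 16): the spatial dims only
# the position-embedding interpolation works on a 2D height x width grid,
# so it needs the spatial patch size rather than the full 3D tubelet size
new_height = new_width = 480
print((new_height // patch_size[0], new_width // patch_size[1]))  # (30, 30)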

Contributor (PR author):

pushed the changes 👍

RUFFY-369 (Contributor, PR author):

@amyeroberts All green and suggestions are pushed.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

RUFFY-369 (Contributor, PR author):

@amyeroberts Weirdly, the slow test passes locally, but in the PR Slow CI that you started it throws the following error:

=================================== FAILURES ===================================
______ VivitModelIntegrationTest.test_inference_interpolate_pos_encoding _______

checkpoint_file = '/mnt/cache/hub/models--google--vivit-b-16x2/snapshots/fc341053d36b42d446b3ffccdbd52452712a23f3/pytorch_model.bin'
is_quantized = False

    def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
        """
        Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
        """
        if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
            # Check format of the archive
            with safe_open(checkpoint_file, framework="pt") as f:
                metadata = f.metadata()
            if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                raise OSError(
                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                    "you save your model with the `save_pretrained` method."
                )
            return safe_load_file(checkpoint_file)
        try:
            if (
                (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
                or (is_fsdp_enabled() and not is_local_dist_rank_0())
            ) and not is_quantized:
                map_location = "meta"
            else:
                map_location = "cpu"
            extra_args = {}
            # mmap can only be used with files serialized with zipfile-based format.
            if (
                isinstance(checkpoint_file, str)
                and map_location != "meta"
                and version.parse(torch.__version__) >= version.parse("2.1.0")
                and is_zipfile(checkpoint_file)
            ):
                extra_args = {"mmap": True}
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
>           return torch.load(
                checkpoint_file,
                map_location=map_location,
                **weights_only_kwarg,
                **extra_args,
            )

src/transformers/modeling_utils.py:575: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.8/dist-packages/torch/serialization.py:1066: in load
    if _is_zipfile(opened_file):
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

f = <_io.BufferedReader name='/mnt/cache/hub/models--google--vivit-b-16x2/snapshots/fc341053d36b42d446b3ffccdbd52452712a23f3/pytorch_model.bin'>

    def _is_zipfile(f) -> bool:
        # This is a stricter implementation than zipfile.is_zipfile().
        # zipfile.is_zipfile() is True if the magic number appears anywhere in the
        # binary. Since we expect the files here to be generated by torch.save or
        # torch.jit.save, it's safe to only check the start bytes and avoid
        # collisions and assume the zip has only 1 file.
        # See bugs.python.org/issue28494.
    
        start = f.tell()
        # Read the first few bytes and match against the ZIP file signature
        local_header_magic_number = b'PK\x03\x04'
>       read_bytes = f.read(len(local_header_magic_number))
E       OSError: [Errno 5] Input/output error

/usr/local/lib/python3.8/dist-packages/torch/serialization.py:206: OSError

During handling of the above exception, another exception occurred:

self = <tests.models.vivit.test_modeling_vivit.VivitModelIntegrationTest testMethod=test_inference_interpolate_pos_encoding>

    @slow
    def test_inference_interpolate_pos_encoding(self):
        # Vivit models have an `interpolate_pos_encoding` argument in their forward method,
        # allowing to interpolate the pre-trained position embeddings in order to use
        # the model on higher resolutions. The DINO model by Facebook AI leverages this
        # to visualize self-attention on higher resolution images.
>       model = VivitModel.from_pretrained("google/vivit-b-16x2").to(torch_device)

tests/models/vivit/test_modeling_vivit.py:362: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
src/transformers/modeling_utils.py:3814: in from_pretrained
    state_dict = load_state_dict(resolved_archive_file)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

checkpoint_file = '/mnt/cache/hub/models--google--vivit-b-16x2/snapshots/fc341053d36b42d446b3ffccdbd52452712a23f3/pytorch_model.bin'
is_quantized = False

    def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
        """
        Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
        """
        if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
            # Check format of the archive
            with safe_open(checkpoint_file, framework="pt") as f:
                metadata = f.metadata()
            if metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
                raise OSError(
                    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                    "you save your model with the `save_pretrained` method."
                )
            return safe_load_file(checkpoint_file)
        try:
            if (
                (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
                or (is_fsdp_enabled() and not is_local_dist_rank_0())
            ) and not is_quantized:
                map_location = "meta"
            else:
                map_location = "cpu"
            extra_args = {}
            # mmap can only be used with files serialized with zipfile-based format.
            if (
                isinstance(checkpoint_file, str)
                and map_location != "meta"
                and version.parse(torch.__version__) >= version.parse("2.1.0")
                and is_zipfile(checkpoint_file)
            ):
                extra_args = {"mmap": True}
            weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
            return torch.load(
                checkpoint_file,
                map_location=map_location,
                **weights_only_kwarg,
                **extra_args,
            )
        except Exception as e:
            try:
                with open(checkpoint_file) as f:
>                   if f.read(7) == "version":
E                   OSError: [Errno 5] Input/output error

src/transformers/modeling_utils.py:584: OSError

I pushed a commit which could fix it, but I'm still a little doubtful.

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run_slow] followed by a comma-separated list of all the models to be tested, e.g. [run_slow] model_to_test_1, model_to_test_2
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests

(For maintainers) The documentation for slow tests CI on PRs is here.


- image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2")
+ image_processor = VivitImageProcessor.from_pretrained("google/vivit-b-16x2-kinetics400")
Collaborator:

Why change the checkpoint and the crop_size in the test?

RUFFY-369 (Contributor, PR author), Sep 30, 2024:

@amyeroberts

  • Changing the checkpoint was a test to see if it fixes the test failure in the PR Slow CI that you triggered (the one I mentioned above).
  • Different crop_size values lead to an error when the interpolation method is called; for example, when the crop_size was crop_size={"height": 480, "width": 480}, the following error occurred:
# add positional encoding to each token
        if interpolate_pos_encoding:
>           embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
E           RuntimeError: The size of tensor a (14401) must match the size of tensor b (901) at non-singleton dimension 1

The same happens with some other crop sizes as well, but the error doesn't occur for crop sizes like 232 or 228, or even the default crop_size of 224.
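
As a rough back-of-the-envelope reading of those two numbers (assuming 32 input frames and a (2, 16, 16) tubelet, as for google/vivit-b-16x2; this is one plausible interpretation, not something stated in the thread):

num_frames, crop = 32, 480
t, h, w = 2, 16, 16
print((num_frames // t) * (crop // h) * (crop // w) + 1)  # 14401: spatio-temporal tokens + [CLS] ("tensor a")
print((crop // h) * (crop // w) + 1)                      # 901: a purely spatial 30x30 grid + [CLS] ("tensor b")

If that reading is right, the interpolated positions cover only a single spatial grid while the token embeddings cover every tubelet, which would explain the mismatch at dimension 1.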

Collaborator:

OK, thanks for explaining. The error shouldn't be triggered for the default crop_size value (no interpolation should happen), but if it works for these non-default values then it's all good :)

Contributor (PR author):

The default crop_size value is 224 in all the image processing files, right? The error doesn't occur for that value, though. The value in the test file is only chosen so that the error doesn't occur, and for the sake of testing a value other than the default one.
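
Putting the thread together, a hedged sketch of the shape of the final test call (232 is used here only because the comments above report that it works; the merged test may use a different value, and image_processor, video and model are the objects set up earlier in the test):

inputs = image_processor(
    video,
    size={"shortest_edge": 232},
    crop_size={"height": 232, "width": 232},  # non-default crop size, per the discussion above
    return_tensors="pt",
)
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)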

@amyeroberts (Collaborator) left a comment:

Thanks for fixing!

amyeroberts merged commit 68a2b50 into huggingface:main on Oct 1, 2024
18 checks passed
RUFFY-369 deleted the fix_interpolate_pos_encoding branch on October 1, 2024 at 19:18
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* fix:test_inference_interpolate_pos_encoding

* style:make style;make fixup

* test: add suggestion to test_modeling_vivit

* chore:add suggestions

* style:make style

* [run_slow] vivit

* ci:slow test fix

* [run_slow] vivit

Successfully merging this pull request may close these issues.

[run_slow test fail] ViViT test_inference_interpolate_pos_encoding