
Conversation

@orrzohar (Contributor)

What does this PR do?

SmolVLM2 support

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [yes] Did you read the contributor guideline, Pull Request section?
  • [no] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@merveenoyan (Contributor)

cc @ArthurZucker FYI, this is needed for the release, otherwise it will be too much work on the users' side 🥲

@ArthurZucker requested a review from molbap on February 11, 2025 at 10:20
@molbap (Contributor) left a comment

Hey, I added a couple of comments! LMK if something is unclear, and ping me when you're done with these; I'll iterate again quickly 🤗
And a first comment: make sure to run the formatter/linter/checker locally. You'll need to install the dev tools within the transformers repo:

```bash
pip install -e .[quality]
```

Then run this command, which covers all the checks needed to make CI happy:

```bash
make fixup
```

@molbap (Contributor) left a comment

Added a couple comments for chat_template, let's go!

@orrzohar (Contributor, Author) commented Feb 11, 2025

@molbap @zucchini-nlp
Overall comments:

  1. I have committed a revision that uses modular transformers, as suggested by @molbap. This required a minor edit to modeling_idefics3 to be compatible with the modular workflow.
  2. I have refactored the inputs to the processor to be closer to the standard.
  3. I have removed our overwrite of apply_chat_template in favor of a new function that converts the video tokens into the expected sequence of text and image tokens, so we can use the original apply_chat_template.
  4. I have updated load_video to accept frame_indices, allowing users to pass which frame indices they are interested in loading. I also created get_video_details, which fetches video metadata (fps, duration, frame count); see the sketch below. We now use load_video rather than having custom video handling logic in smolvlm.

I believe all the comments on the PR have been addressed. Let me know if I missed anything or if you have any new comments.
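A rough usage sketch of point 4 (the import path, argument names, and return shape here are assumptions, not verified against the final diff):

```python
# Sketch only: import path and signature are assumptions, not the verified API.
from transformers.image_utils import load_video

# get_video_details-style metadata (fps, duration, frame count) would let a
# caller compute the indices it wants before loading.
frame_indices = [0, 8, 16, 24]  # hypothetical: every 8th frame
frames = load_video("clip.mp4", frame_indices=frame_indices)
```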

@zucchini-nlp (Member) left a comment
Thanks!

@ArthurZucker (Collaborator) left a comment

Great work everyone!
Not a super fan of the processor / video processor each doing half of the job; we really need to split the work and keep it simpler:
the video processor should just return post-processed videos / sampled frames plus some metadata, while the processor should merge text with these!
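For example, the split described above could look like this (a sketch with hypothetical names, not the API in this PR):

```python
# Hypothetical division of labor, for illustration only.
video_outputs = video_processor(video, num_frames=8)
# -> {"pixel_values": ..., "metadata": {"fps": ..., "duration": ...}}

# The processor alone interleaves text tokens with frame placeholders.
inputs = processor(text=prompt, videos=video_outputs["pixel_values"])
```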

Main question: what about multi-turn?

```python
    in forward. Instead, we override inputs_merger here with custom logic.
    """

    def inputs_merger(
```
Collaborator:

As I mentioned to @zucchini-nlp, the way you train, and whether you use DeepSpeed or not, is irrelevant to transformers. If this is training-specific / data pre-processing, it should happen outside the modeling code. Please add a data collator for SmolVLM if you want people to use this for training!
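A minimal sketch of such a collator (hypothetical class and field names, not part of this PR):

```python
class SmolVLMDataCollator:
    """Pads a batch with the processor and builds training labels."""

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = [ex["text"] for ex in examples]
        images = [ex["images"] for ex in examples]
        batch = self.processor(
            text=texts, images=images, padding=True, return_tensors="pt"
        )
        labels = batch["input_ids"].clone()
        # Mask padding (and, in practice, image placeholder tokens) out of the loss.
        labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
```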

Collaborator:

Let's merge @molbap's suggestion here!

Member:

Maybe we can just use the Idefics3 merger then, with modular copying it? Looks pretty similar to me.
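In modular-transformers terms, that reuse could look something like this (a sketch; the actual modular file may differ):

```python
# modular_smolvlm.py (sketch): inheriting from Idefics3 pulls in its
# inputs_merger implementation instead of re-implementing it.
from transformers.models.idefics3.modeling_idefics3 import Idefics3Model


class SmolVLMModel(Idefics3Model):
    pass  # inputs_merger inherited unchanged from Idefics3Model
```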

Comment on lines +237 to +239
```python
if not any(real_images_inds):
    # no images, leave one empty image.
    real_images_inds[0] = True
```
Collaborator:

What does `inds` mean?

```python
# Handle the vision attention mask
if pixel_attention_mask is None:
    pixel_attention_mask = torch.ones(
        size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
```
Collaborator:

the shape function can help you 😄
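Presumably something like the following (a sketch of the suggestion; the dtype and device arguments are assumptions):

```python
# Unpacking pixel_values.shape reads better than chained .size() calls.
batch_size, _, height, width = pixel_values.shape
pixel_attention_mask = torch.ones(
    (batch_size, height, width),
    dtype=torch.bool,  # assumption: the mask dtype used elsewhere in the model
    device=pixel_values.device,
)
```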

Comment on lines 55 to 64
```python
def is_url(val) -> bool:
    return isinstance(val, str) and val.startswith("http")


def is_str(val) -> bool:
    return isinstance(val, str)


def is_image_or_image_url(elem):
    return is_url(elem) or is_valid_image(elem)
```
Collaborator:

Not a fan of these, especially given that we define them in processing utils or image utils, which is where replacing the input chat template usually happens!

```python
}


SmolVLMProcessorKwargs.__annotations__["images_kwargs"] = SmolVLMImagesKwargs  # python 3.8 compatibility
```
Collaborator:

We no longer support Python 3.8, so this isn't needed.
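Without the 3.8 workaround, the annotation can live directly on the class (a sketch, assuming the usual ProcessingKwargs pattern):

```python
# Sketch: declare the annotation directly instead of patching __annotations__.
class SmolVLMProcessorKwargs(ProcessingKwargs, total=False):
    images_kwargs: SmolVLMImagesKwargs
```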

```python
        the docstring of this method for more information.
        """
        decode_output = self.tokenizer.decode(*args, **kwargs)
        return self._regex_to_remove_extra_special_tokens.sub("<image>", decode_output)
```
Collaborator:

Sorry, not sure I understand this?

```python
        refer to the docstring of this method for more information.
        """
        batched_decode_output = self.tokenizer.batch_decode(*args, **kwargs)
        return [self._regex_to_remove_extra_special_tokens.sub("<image>", s) for s in batched_decode_output]
```
Collaborator:

I don't get why we have to do this either?

Comment on lines 176 to 177
```python
# Matches one or more occurrences of <row_x_col_y> tags (where x and y are digits), optionally surrounded by newline characters
self._regex_to_remove_extra_special_tokens = re.compile(r"(<row_\d+_col_\d+>\n?)+")
```
Collaborator:

If you add them as special tokens, you won't have to go through all this trouble. These tokens are processed by the model anyway, no?
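A sketch of that suggestion (note the behavior differs slightly: skip_special_tokens drops the tags entirely, while the current regex replaces each run with a single <image>):

```python
# Register the grid tags as additional special tokens once...
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<row_1_col_1>", "<row_1_col_2>"]}  # etc.
)
# ...then decoding can drop them without any regex post-processing.
text = tokenizer.decode(generated_ids, skip_special_tokens=True)
```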

Contributor (Author):

I left this in to stay compatible with SmolVLM1 -- I think we can drop it if we don't want to support that class here (I will just need to test).

Comment on lines 324 to 331
```python
if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
    # When we generate, we don't want to replace the potential image_token_id that we generated by images
    # that simply don't exist
    inputs_embeds = self.inputs_merger(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        image_hidden_states=image_hidden_states,
    )
```
Collaborator:

Question again here: this disables multi-turn image processing. Basically, what if you want to use a different video or pass another image?

@molbap (Contributor) commented Feb 19, 2025

Pushed a fix for test_training being unhappy, plus an attempt at vectorizing the merger that seems to work; it would need another eye on it :)
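For context, a vectorized merger typically looks something like this (a generic sketch, not necessarily the exact code pushed here):

```python
# Replace every image placeholder embedding with the corresponding image
# embedding in one masked_scatter call, avoiding a Python loop over positions.
special_image_mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
image_embeds = image_hidden_states.to(inputs_embeds.dtype).reshape(-1, inputs_embeds.shape[-1])
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)
```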

@zucchini-nlp (Member)

@molbap thanks! I see you also added back the auto-map. The model was removed from MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES intentionally. We should aim to use only the ImageTextToText mapping for VLMs, instead of duplicating it across two auto-classes.

@molbap (Contributor) commented Feb 19, 2025

Yeah, it's just that if the model is not in the correct mappings, test_training fails unless we manually add some labels (otherwise they are not added by the auto class mapper). Plus, Idefics3 had this double mapping AFAIK?
+1 to removing it though, as long as tests pass, especially the training tests!

@zucchini-nlp (Member) commented Feb 19, 2025

@molbap I think we can add the ImageTextToText mapping in _prepare_for_class to make the tests happy, along with the CausalLM mapping etc. Idefics3 indeed has both, and so do all other VLMs, but it doesn't really make sense to use Vision2Seq for new models, especially after we released the pipeline and have been promoting the image-text-to-text tag on the Hub.

@molbap (Contributor) commented Feb 19, 2025

Alright, let me do that! Then it should be good.

@LysandreJik (Member)

Thanks all!

@LysandreJik merged commit 4397dfc into huggingface:main on Feb 20, 2025
19 of 21 checks passed