Conversation

@tomaarsen
Member

What does this PR do?

  • Add return_dict to get_text_features & get_image_features methods to allow returning 'BaseModelOutputWithPooling'

Fixes #42401

Well, the architectures supporting get_image_features are all extremely different, with wildly different outputs from their get_image_features methods:

  • 2d outputs,
  • 3d outputs,
  • lists of 2d outputs (due to non-matching shapes),
  • an existing 'return_attentions' flag resulting in a 2-tuple,
  • an existing 'return_dict' flag resulting in 3-tuples (???),
  • high quality image embeddings,
  • low quality image embeddings,
  • deepstack image embeddings,
  • etc. etc. etc.

And I only went through like 70-80% of all architectures with get_image_features before I gave up.

Standardisation of all of these sounds like a lost cause. cc @zucchini-nlp I'm curious about your thoughts here. When I did some preliminary research, I only ran into a handful of cases, and I figured we'd be able to reformat them all into one format, but I'm not sure anymore. I added # NOTE: @Tom ... where I figured we might have big problems with standardisation.

For get_text_features it's a lot simpler: there's only one architecture (blip-2) that differs from all the others.

I haven't started on get_audio_features and get_video_features, but there's not too much of a point if we can't get get_image_features normalized.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp @ArthurZucker @Cyrilvallez


…ModelOutputWithPooling'

Added to all architectures except blip-2, which has a much different structure here. It uses 'Blip2TextModelWithProjection' to get these embeddings/features, but this class isn't as simple to use
…eModelOutputWithPooling'

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@zucchini-nlp zucchini-nlp left a comment


We discussed this internally and decided to add last_hidden_states to all models as the last state from the vision block. The pooled embeddings will stay in different shapes, as is.

For the last hidden state the shapes are already more standardized, with a few major options. The only special cases might be qwen-like models, where each image encoding has a different sequence length and thus the outputs are concatenated as length*dim.
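
For illustration, a minimal sketch of what that concatenation looks like, with hypothetical shapes and tensor names (not taken from any particular model):

import torch

# Two images whose patch sequences have different lengths
image_a_feats = torch.randn(196, 1152)  # image A: 196 patches
image_b_feats = torch.randn(256, 1152)  # image B: 256 patches

# qwen-like models concatenate along the sequence axis instead of padding,
# so the "last hidden state" becomes a single (196 + 256, 1152) tensor
concatenated = torch.cat([image_a_feats, image_b_feats], dim=0)
print(concatenated.shape)  # torch.Size([452, 1152])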

@tomaarsen
Member Author

The initial work on all 4 modalities is done, with a handful of exceptions. There are about 2 or 3 breaking architectures, specifically architectures that already supported return_dict and return_attentions. Typings, docstrings, and tests still have to be added, but I'm curious whether this has a chance of being merged before I continue with those.


Member

@zucchini-nlp zucchini-nlp left a comment


Thanks a lot for the changes. I see there are a few tricky models that do not fit neatly into BaseModelOutput.

To wrap it up: to make this work, firstly we need to ensure that all vision encoders are capable of returning a dict in the way that PreTrainedModels do, i.e. by checking config.return_dict and returning attentions, hidden states, pooled output, etc. Then we can ask get_image_features to return the same dict that was output by the encoder (optionally, the pooled output is updated in VLMs). That will preserve all fields of the vision encoder output.

I think the current state of the PR is already doing this, apart from a few non-standard models. I left comments under those models, so lmk if that makes sense.
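
A minimal sketch of that flow, using hypothetical module names (vision_tower, multi_modal_projector) rather than any model's actual implementation:

def get_image_features(self, pixel_values, return_dict=False, **kwargs):
    # The vision encoder returns a BaseModelOutputWithPooling (including
    # attentions/hidden states if requested), the same way PreTrainedModels do
    vision_outputs = self.vision_tower(pixel_values, return_dict=True, **kwargs)
    # The VLM projects the encoder output and stores it as the pooled output,
    # keeping every other field of the encoder output intact
    image_embeds = self.multi_modal_projector(vision_outputs.last_hidden_state)
    if return_dict:
        vision_outputs.pooler_output = image_embeds
        return vision_outputs
    return image_embeds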

Comment on lines 821 to 827

if return_dict:
    return BaseModelOutputWithPooling(
        last_hidden_state=hidden_states,
        pooler_output=merged_hidden_states,
    )

Member


Totally aligned with this, very much needed! I think in qwen-like models, the downsampling and merging are both part of the multimodal adapter. Usually in vision models the last_hidden_state is the last state after all encoder blocks and before the layer norm (e.g. CLIP, SigLIP).

IMO qwen-vision needs the same format
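
In other words, the suggestion amounts to something like the following (simplified, with assumed variable names):

if return_dict:
    return BaseModelOutputWithPooling(
        # state straight after the encoder blocks, before the patch merging /
        # downsampling that conceptually belongs to the multimodal adapter
        last_hidden_state=encoder_hidden_states,
        # merged/downsampled features remain available as the pooled output
        pooler_output=merged_hidden_states,
    )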

Comment on lines 595 to 597
return_dict (`bool`, *optional*, defaults to `False`):
    Whether to return a `ModelOutput` instead of a pooled embedding.
Member


let's add complete docs if they were missing for other args

…model_inputs

The changes in check_model_inputs aren't the clearest/prettiest, but they work well for now.
@tomaarsen
Member Author

I've pushed a proposal in 9a251ce that takes this in a bit of a different direction by adopting the modern TransformersKwargs and check_model_inputs. I updated the latter so that the pooler_output is returned by default, unless the user explicitly passes return_dict=True (which returns a ModelOutput subclass) or return_dict=None (which falls back to the model config's return_dict to decide between a ModelOutput and the pooled embeddings).

I can extend this to more architectures, but want to get your view on this first.
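
As a minimal sketch of the intended resolution logic (a hypothetical helper, not the actual check_model_inputs code):

def resolve_features_output(outputs, return_dict, config):
    # return_dict=False (the new default) -> pooled embeddings only
    # return_dict=True                    -> the full ModelOutput subclass
    # return_dict=None                    -> fall back to config.return_dict
    if return_dict is None:
        return_dict = config.return_dict
    return outputs if return_dict else outputs.pooler_output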

Usage:

from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image
import torch

model = AutoModel.from_pretrained("openai/clip-vit-large-patch14", attn_implementation="eager")
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = load_image(url)
image_inputs = processor(images=image, return_tensors="pt")
text_inputs = processor(text=["a photo of a cat"], return_tensors="pt")
joint_inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt")

def print_output(output):
    if isinstance(output, torch.Tensor):
        print("Output is a tensor with shape:", output.shape)
    else:
        print("Output is a ModelOutput with attributes:")
        for key, value in output.items():
            if isinstance(value, torch.Tensor):
                print(f"  {key}: tensor with shape {value.shape}")
            else:
                print(f"  {key}: {type(value)}")
    print()

with torch.inference_mode():
    image_features = model.get_image_features(**image_inputs)
    print("model.get_image_features(**image_inputs) outputs:")
    print_output(image_features)

    image_features = model.get_image_features(**image_inputs, return_dict=True)
    print("model.get_image_features(**image_inputs, return_dict=True) outputs:")
    print_output(image_features)

    image_features = model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
    print("model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:")
    print_output(image_features)

    text_features = model.get_text_features(**text_inputs)
    print("model.get_text_features(**text_inputs) outputs:")
    print_output(text_features)

    text_features = model.get_text_features(**text_inputs, return_dict=True)
    print("model.get_text_features(**text_inputs, return_dict=True) outputs:")
    print_output(text_features)

    text_features = model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True)
    print("model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:")
    print_output(text_features)

Outputs:

model.get_image_features(**image_inputs) outputs:
Output is a tensor with shape: torch.Size([1, 768])

model.get_image_features(**image_inputs, return_dict=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 257, 1024])
  pooler_output: tensor with shape torch.Size([1, 768])

model.get_image_features(**image_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 257, 1024])
  pooler_output: tensor with shape torch.Size([1, 768])
  hidden_states: <class 'tuple'>
  attentions: <class 'tuple'>

model.get_text_features(**text_inputs) outputs:
Output is a tensor with shape: torch.Size([1, 768])

model.get_text_features(**text_inputs, return_dict=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 7, 768])
  pooler_output: tensor with shape torch.Size([1, 768])

model.get_text_features(**text_inputs, return_dict=True, output_hidden_states=True, output_attentions=True) outputs:
Output is a ModelOutput with attributes:
  last_hidden_state: tensor with shape torch.Size([1, 7, 768])
  pooler_output: tensor with shape torch.Size([1, 768])
  hidden_states: <class 'tuple'>
  attentions: <class 'tuple'>

….._features methods

This commit updates all get_text_features methods, even blip_2, which had not been attempted previously.
A handful of outliers aren't updated yet, e.g. where two or more ModelOutput classes are viable, or the VQ-based ones.

For context, the other modeling file classes haven't been updated with the new get_..._features format, nor have the tests
@tomaarsen
Member Author

tomaarsen commented Dec 16, 2025

For context, these are the TODOs at this point:

  • Unfinished architectures
    • fuyu get_image_features: I don't think Fuyu has a real Vision Encoder beyond just a single Linear
    • blip_2 get_image_features: The new format misses the query_outputs/qformer_outputs; it should use a new ModelOutput subclass somehow (see the sketch after this list).
    • instructblip get_image_features: The new format misses the query_outputs/qformer_outputs, should use a new ModelOutput subclass somehow.
    • instructblipvideo get_video_features: See above
    • kosmos2 get_image_features: The new format misses the projection_attentions, should use a new ModelOutput subclass somehow.
    • ovis2 get_image_features: The new format misses the visual_indicator_features, should use a new ModelOutput subclass somehow.
    • deepseek_vl_hybrid get_image_features: This method produces both low_res_vision_encodings and high_res_vision_encodings, should use a new ModelOutput subclass somehow to combine them.
    • chameleon get_image_features: Update the VQVAE class to output the hidden states before quantization.
    • emu3 get_image_features: Update the VQVAE class to output the hidden states before quantization.
  • Update all architecture classes to accept the new output format
  • Add and/or update tests for the new output format
  • Update docstrings
  • Update type hints
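
For the items above that need an extra field (blip_2, instructblip, kosmos2, ovis2), one option is a small ModelOutput subclass; a sketch with an assumed class name and field, not something that exists in the PR:

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling

@dataclass
class ImageFeaturesOutput(BaseModelOutputWithPooling):
    # Extra field that BaseModelOutputWithPooling has no slot for,
    # e.g. the Q-Former outputs of blip_2-like models
    query_outputs: Optional[torch.FloatTensor] = None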


The Fuyu architecture doesn't have an image encoder:
> Architecturally, Fuyu is a vanilla decoder-only transformer - there is no image encoder.
@tomaarsen
Member Author

I introduced a ModelOutput in f082a8e for Chameleon, although there are many different approaches we can take there. For example, the quantized_last_hidden_state and emb_loss in ChameleonVQVAE.encode are never used, but I've chosen to still return them, although it's unusual to return a loss in a ModelOutput like this. I'm curious about your thoughts on this one @zucchini-nlp. If it seems alright, then I can (presumably) copy the approach to Emu3, which uses a similar VQVAE (although that one luckily only outputs the image_tokens currently, to which I can add the last_hidden_state).
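
For reference, the rough shape of such an output (an illustrative sketch, not necessarily the exact class from f082a8e):

from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput

@dataclass
class ChameleonVQVAEEncodeOutput(ModelOutput):
    # Hidden states before quantization, usable as image features
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Quantized counterpart and codebook/embedding loss; currently unused by
    # callers, but returned for completeness
    quantized_last_hidden_state: Optional[torch.FloatTensor] = None
    emb_loss: Optional[torch.FloatTensor] = None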


@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: aimv2, align, altclip, aria, audioflamingo3, aya_vision, blip, blip_2, chameleon, chinese_clip, clap, clip, clipseg, clvp, cohere2_vision, colqwen2

@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=42564&sha=7af0b6

@tomaarsen tomaarsen marked this pull request as ready for review December 18, 2025 17:29

Development

Successfully merging this pull request may close these issues.

The get_(text|image|audio|video)_features methods have inconsistent output formats, needs aligning for Sentence Transformers
