
[VLM] Support Pixtral-HF on V1 #14275


Merged
merged 3 commits on Mar 6, 2025
6 changes: 1 addition & 5 deletions docs/source/models/supported_models.md
@@ -866,7 +866,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `PixtralForConditionalGeneration`
* Pixtral
* T + I<sup>+</sup>
* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b` (see note), etc.
* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc.
*
* ✅︎
* ✅︎
@@ -930,10 +930,6 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
Currently the PaliGemma model series is implemented without PrefixLM attention mask. This model series may be deprecated in a future release.
:::

:::{note}
`mistral-community/pixtral-12b` does not support V1 yet.
:::

:::{note}
To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`.
:::
171 changes: 164 additions & 7 deletions vllm/model_executor/models/llava.py
@@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, TypeVar, Union)
TypedDict, TypeVar, Union, cast)

import torch
import torch.nn as nn
@@ -35,6 +35,7 @@
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, flatten_2d_lists, json_map_leaves

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
@@ -56,6 +57,25 @@ class LlavaImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""


class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -65,6 +85,25 @@ class LlavaImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""

feat_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image features correspond
to patch tokens.

Shape: `(batch_size, num_crops, num_patch)`
"""

embed_is_patch: Union[torch.Tensor, List[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.

Shape: `(batch_size, num_embeds)`
"""

num_crops: torch.Tensor
"""Shape: `(batch_size, num_images)`"""


LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
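As a worked illustration of the new mask fields (a sketch with made-up numbers, not part of the diff): for one image whose Pixtral-HF feature grid is 3 columns by 2 rows, each grid row contributes its patch slots plus one non-patch row-break slot, mirroring the mask construction in `_call_hf_processor` further down.

```python
import torch

ncols, nrows = 3, 2  # hypothetical feature grid of a single image

# One slot per patch plus one break slot at the end of each row.
embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)
num_crops = (ncols + 1) * nrows

print(num_crops)                  # 8 embedding slots in total
print(embed_is_patch.tolist())    # [True, True, True, False, True, True, True, False]
print(int(embed_is_patch.sum()))  # 6 slots actually carry vision features
```

Because different images produce masks of different lengths, the fields are typed `Union[torch.Tensor, List[torch.Tensor]]` rather than a single batched tensor.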

@@ -317,14 +356,40 @@ def _call_hf_processor(
for p, (h, w) in zip(pixel_values, image_sizes)
]

hf_config = self.info.get_hf_config()

tile_sizes = [
get_pixtral_hf_image_feature_grid_size(
hf_config.vision_config,
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2])
for pixel_value in processed_outputs["pixel_values"]
]
num_crops = torch.tensor([(ncols + 1) * nrows
for ncols, nrows in tile_sizes])
# Each image may result in masks of different sizes, so we need to
# flatten the list and later use `num_crops` to get per-image masks.
embed_is_patch = torch.tensor(
flatten_2d_lists([([True] * ncols + [False]) * nrows
for ncols, nrows in tile_sizes]))
processed_outputs["num_crops"] = num_crops
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["feat_is_patch"] = embed_is_patch

return processed_outputs

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0)).view(-1)
return dict(
feat_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
embed_is_patch=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
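Read together with `_call_hf_processor`, the `flat_from_sizes` wiring means the per-image masks are stored flattened end to end, while `num_crops` records how many slots belong to each image. A toy sketch of the slicing this implies (illustrative values only, not the actual `MultiModalFieldConfig` internals):

```python
import torch

# Flattened masks for a prompt with two hypothetical images
# occupying 6 and 4 embedding slots respectively.
embed_is_patch = torch.tensor([True, True, False, True, True, False,   # image 0
                               True, True, True, False])               # image 1
num_crops = torch.tensor([6, 4])

# Recover the per-image masks by splitting the flat tensor by size.
per_image = torch.split(embed_is_patch, num_crops.tolist())
print([m.tolist() for m in per_image])
# [[True, True, False, True, True, False], [True, True, True, False]]
```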
@@ -562,6 +627,23 @@ def _parse_and_validate_image_input(
if pixel_values is None and image_embeds is None:
return None

feat_is_patch = kwargs.pop("feat_is_patch", None)
if feat_is_patch is not None and not isinstance(
feat_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of feat_is_patch. "
f"Got type: {type(feat_is_patch)}")

embed_is_patch = kwargs.pop("embed_is_patch", None)
if embed_is_patch is not None and not isinstance(
embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")

num_crops = kwargs.pop("num_crops", None)
if num_crops is not None and not isinstance(num_crops, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")

if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
@@ -571,12 +653,18 @@
return LlavaImagePixelInputs(
type="pixel_values",
data=flatten_bn(pixel_values),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

if image_embeds is not None:
@@ -587,6 +675,9 @@
return LlavaImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
feat_is_patch=feat_is_patch,
embed_is_patch=embed_is_patch,
num_crops=num_crops,
)

raise AssertionError("This line should be unreachable.")
@@ -633,16 +724,74 @@ def _process_image_input(self,

assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)

def get_multimodal_embeddings(
self, **kwargs
) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...]]:
if isinstance(image_features, torch.Tensor):
return self.multi_modal_projector(image_features)

feature_sizes = [
image_feature.shape[0] for image_feature in image_features
]

image_embeds = self.multi_modal_projector(torch.cat(image_features))
image_embeds = torch.split(image_embeds, feature_sizes)
return image_embeds
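The list branch above batches the projector over a ragged set of images: record each image's feature count, run one projection over the concatenation, then split the result back per image. A standalone sketch with a toy projector and made-up sizes (none of these values come from the PR):

```python
import torch
import torch.nn as nn

# Toy stand-ins: a linear "projector" and two images with different
# numbers of vision features (e.g. different crop counts).
projector = nn.Linear(16, 32)
image_features = [torch.randn(6, 16), torch.randn(4, 16)]

feature_sizes = [f.shape[0] for f in image_features]    # [6, 4]
image_embeds = projector(torch.cat(image_features))     # one batched matmul
image_embeds = torch.split(image_embeds, feature_sizes)

print([e.shape for e in image_embeds])  # [torch.Size([6, 32]), torch.Size([4, 32])]
```

Running the projector once over the concatenation avoids a Python loop over images while still returning per-image tensors.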

def _get_mm_embeds(
self,
features: torch.Tensor, # Shape: (num_crop, num_patch, d)
feat_is_patch: torch.Tensor, # Shape: (num_crop, num_patch)
num_crops: torch.Tensor, # Shape: (num_images,)
embed_is_patch: torch.Tensor, # Shape: (num_embeds,)
) -> list[torch.Tensor]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.

Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""

# Insert columns of nan values according to `feat_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both the V0 and V1 paths. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
feat_is_patch = feat_is_patch.view(-1)
embed_is_patch = embed_is_patch.view(-1)
expanded_embedding = torch.full(
(sum(num_crops), *features.shape[1:]),
torch.nan,
dtype=features.dtype).to(features.device)
expanded_embedding[feat_is_patch] = features

num_crops_per_image = num_crops.tolist()
feats_per_image = expanded_embedding.split(num_crops_per_image)
f_is_patch_per_image = feat_is_patch.split(num_crops_per_image)

embed_dim = expanded_embedding.shape[-1]
num_embeds = embed_is_patch.shape[0]

embeds_in_batch = list[torch.Tensor]()
for feats, f_is_patch in zip(feats_per_image, f_is_patch_per_image):
embeds = feats.new_full((num_embeds, embed_dim), torch.nan)
embeds[embed_is_patch] = feats[f_is_patch]
embeds_in_batch.append(embeds)

return embeds_in_batch
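The NaN-scatter trick in `_get_mm_embeds` is easier to see on toy data. A self-contained sketch for a single image with two patch features and one non-patch slot (values are made up; the real code additionally splits by `num_crops` across images):

```python
import torch

embed_dim = 4
# Mask over one image's embedding slots: slots 0 and 1 are patches,
# slot 2 is a non-patch token with no vision feature behind it.
embed_is_patch = torch.tensor([True, True, False])

# Two projected patch features for that image.
features = torch.arange(2 * embed_dim, dtype=torch.float32).view(2, embed_dim)

# Scatter the patches into a NaN-padded tensor with one row per slot.
embeds = features.new_full((embed_is_patch.numel(), embed_dim), torch.nan)
embeds[embed_is_patch] = features

print(embeds)  # rows 0-1 hold the features, row 2 is all NaN
```

Downstream, the all-NaN rows are what let `get_input_embeddings` tell patch embeddings apart from placeholder slots without passing the masks any further.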

def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
if kwargs.get("v0_path", False):
return vision_embeddings
else:
nested_emb = [
self._get_mm_embeds(*args) for args in zip(
vision_embeddings, image_input["feat_is_patch"],
image_input["num_crops"], image_input["embed_is_patch"])
]
return flatten_2d_lists(nested_emb)

def get_input_embeddings(
self,
@@ -651,8 +800,15 @@ def get_input_embeddings(
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# Extract the patch tokens
patch_embeddings = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
input_ids, inputs_embeds, cast(NestedTensors,
patch_embeddings),
self.config.image_token_index)
return inputs_embeds
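The lambda passed to `json_map_leaves` recovers exactly the patch rows because every row produced by `_get_mm_embeds` is either fully real or fully NaN. A toy sketch of that idiom (values are made up):

```python
import torch

embed_dim = 4
# One image's embeddings as emitted by _get_mm_embeds: two real patch
# rows followed by one all-NaN placeholder row.
x = torch.full((3, embed_dim), torch.nan)
x[0] = 1.0
x[1] = 2.0

# Boolean masking with ~isnan() flattens the tensor, so view() is used
# to restore the (num_patches, embed_dim) shape afterwards.
patches = x[~x.isnan()].view(-1, *x.shape[1:])
print(patches.shape)  # torch.Size([2, 4])
```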

@@ -705,6 +861,7 @@ def forward(
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
4 changes: 2 additions & 2 deletions vllm/model_executor/models/molmo.py
@@ -1484,8 +1484,8 @@ def _parse_and_validate_image_input(

img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
raise ValueError("Incorrect type of img_patch_id. "
f"Got type: {type(img_patch_id)}")
self.img_patch_id = img_patch_id.flatten().unique().item()

return MolmoImageInputs(
10 changes: 8 additions & 2 deletions vllm/model_executor/models/pixtral.py
@@ -1042,9 +1042,13 @@ def forward(
for img in pixel_values
]

patch_embeds = [
p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list
]
embed_sizes = [p.shape[1] for p in patch_embeds]

# flatten to a single sequence
patch_embeds = torch.cat(
[p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
patch_embeds = torch.cat(patch_embeds, dim=1)
patch_embeds = self.ln_pre(patch_embeds)

# positional embeddings
@@ -1075,6 +1079,8 @@ def forward(
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)

# squeeze dim 0 and split into separate tensors for each image
out = torch.split(torch.squeeze(out), embed_sizes)
return out
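A small standalone sketch of the new bookkeeping (toy shapes, no Pixtral weights): flatten each conv feature map into a token sequence, remember its length, run everything as one concatenated sequence, then split the output back per image.

```python
import torch

d = 5  # toy hidden size
# Conv outputs for two images of different sizes: (1, d, h, w) each.
patch_embeds_list = [torch.randn(1, d, 2, 3), torch.randn(1, d, 4, 2)]

# (1, d, h, w) -> (1, h*w, d): one token per patch, channels last.
patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
embed_sizes = [p.shape[1] for p in patch_embeds]    # [6, 8]

# Concatenate into one sequence (the transformer is skipped here),
# then squeeze the batch dim and split back into per-image tensors.
out = torch.cat(patch_embeds, dim=1)                # (1, 14, d)
out = torch.split(torch.squeeze(out), embed_sizes)

print([o.shape for o in out])  # [torch.Size([6, 5]), torch.Size([8, 5])]
```

Splitting at the encoder output keeps variable-resolution images independent, which is what the per-image handling added in `llava.py` above relies on.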

# (TODO) Add prefix argument for filtering out weights to be loaded