Commit 75404d0

[VLM] Update compatibility with transformers 4.49

Parent: bf3b79e

9 files changed: +50 -59 lines
docs/source/models/supported_models.md

Lines changed: 1 addition & 2 deletions
```diff
@@ -883,8 +883,7 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 :::
 
 :::{note}
-The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingface.co/mistral-community/pixtral-12b/discussions/22)).
-A corrected version is available at <gh-file:examples/template_pixtral_hf.jinja>.
+`mistral-community/pixtral-12b` does not support V1 yet.
 :::
 
 :::{note}
```

examples/template_pixtral_hf.jinja

Lines changed: 0 additions & 38 deletions
This file was deleted.

tests/entrypoints/test_chat_utils.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -761,7 +761,6 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     ("template_falcon.jinja", "string"),
     ("template_inkbot.jinja", "string"),
     ("template_llava.jinja", "string"),
-    ("template_pixtral_hf.jinja", "openai"),
     ("template_vlm2vec.jinja", "openai"),
     ("tool_chat_template_granite_20b_fc.jinja", "string"),
     ("tool_chat_template_hermes.jinja", "string"),
```

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -224,7 +224,7 @@
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
@@ -359,7 +359,7 @@
         marks=[
             pytest.mark.skipif(
                 Version(TRANSFORMERS_VERSION) >= Version("4.48"),
-                reason="HF model is not compatible with transformers>=4.48.0",
+                reason="HF model is not compatible with transformers>=4.48",
             )
         ],
     ),
```
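Only the skip reason string changes here; the comparison itself is untouched. Under PEP 440, which `packaging.version.Version` implements, `4.48` and `4.48.0` parse to the same version, so the message is simply brought in line with the spelling used in the comparison. A quick sanity check (assuming `packaging` is installed):

```python
from packaging.version import Version

# PEP 440: trailing zero segments are insignificant, so both
# spellings of the cutoff denote the same version.
assert Version("4.48") == Version("4.48.0")
assert Version("4.48.1") >= Version("4.48")
```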

tests/models/embedding/vision_language/test_llava_next.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -4,7 +4,6 @@
 
 import pytest
 import torch.nn.functional as F
-import transformers
 from transformers import AutoModelForVision2Seq
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -57,6 +56,10 @@ def _run_test(
 
     with hf_runner(model, dtype=dtype,
                    auto_cls=AutoModelForVision2Seq) as hf_model:
+        # Patch the issue where generation_config.json is missing
+        hf_model.processor.patch_size = \
+            hf_model.model.config.vision_config.patch_size
+
         # Patch the issue where image_token_id
         # exceeds the maximum allowed vocab size
         hf_model.model.resize_token_embeddings(
@@ -88,8 +91,6 @@ def _run_test(
     )
 
 
-@pytest.mark.skipif(transformers.__version__ >= "4.46",
-                    reason="Model broken with changes in transformers 4.46")
 @pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
```

vllm/model_executor/models/llava.py

Lines changed: 23 additions & 10 deletions
```diff
@@ -293,16 +293,29 @@ def _call_hf_processor(
 
         pixel_values = processed_outputs.get("pixel_values")
         if pixel_values is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, C, H, W)
-            # New output: (num_images, C, H, W)
-            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
-            assert (isinstance(pixel_values[0], list)
-                    and len(pixel_values[0]) == len(images))
-
-            processed_outputs["pixel_values"] = pixel_values[0]
+            # Before/after https://github.com/huggingface/transformers/pull/35122
+            if Version(TRANSFORMERS_VERSION) <= Version("4.48.2"):
+                images = mm_data["images"]
+                assert isinstance(images, list)
+
+                # Original output: (1, num_images, C, H, W)
+                # New output: (num_images, C, H, W)
+                assert (isinstance(pixel_values, list)
+                        and len(pixel_values) == 1)
+                assert (isinstance(pixel_values[0], list)
+                        and len(pixel_values[0]) == len(images))
+
+                processed_outputs["pixel_values"] = pixel_values[0]
+            else:
+                # Avoid padding since we need the output for each image to be
+                # independent of other images for the cache to work correctly
+                image_sizes = processed_outputs["image_sizes"]
+                assert len(pixel_values) == len(image_sizes)
+
+                processed_outputs["pixel_values"] = [
+                    p[:, :h, :w]
+                    for p, (h, w) in zip(pixel_values, image_sizes)
+                ]
 
         return processed_outputs
```
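To see what the new `else` branch does: after the linked transformers PR, the processor pads every image in a batch to a common height and width and reports the true sizes in `image_sizes`, so slicing with `p[:, :h, :w]` strips the padding and keeps each image's output independent of its batchmates. A minimal sketch with made-up shapes (per-image tensors assumed here to be `(C, H_max, W_max)`):

```python
import torch

# Hypothetical padded batch: two images padded to a common (H_max, W_max)
pixel_values = torch.randn(2, 3, 64, 64)  # (num_images, C, H_max, W_max)
image_sizes = [(64, 48), (32, 64)]        # true (height, width) per image

# Strip padding so each image's features depend only on itself,
# which is what the caching comment in the diff is about
unpadded = [p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)]

assert [tuple(t.shape) for t in unpadded] == [(3, 64, 48), (3, 32, 64)]
```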

vllm/model_executor/models/llava_next.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -73,7 +73,15 @@ def get_hf_config(self) -> LlavaNextLikeConfig:
         return self.ctx.get_hf_config(LlavaNextConfig)
 
     def get_hf_processor(self):
-        return self.ctx.get_hf_processor(LlavaNextProcessor)
+        hf_processor = self.ctx.get_hf_processor(LlavaNextProcessor)
+
+        # In case patch_size is omitted from `processor_config.json`
+        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
+        if hf_processor.patch_size is None:
+            patch_size = self.get_vision_encoder_info().get_patch_size()
+            hf_processor.patch_size = patch_size
+
+        return hf_processor
 
     # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
     def get_num_image_tokens(
```
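For readers hitting the same issue with raw transformers, the equivalent fallback can be sketched outside vLLM. This is illustrative only: `royokong/e5-v` comes from the comment in the diff, and the assumption that its processor config omits `patch_size` is exactly what the commit is patching around.

```python
from transformers import AutoConfig, AutoProcessor

model_id = "royokong/e5-v"  # example model from the diff comment

processor = AutoProcessor.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)

# Mirror the commit's fallback: take patch_size from the vision tower
# when processor_config.json leaves it unset
if getattr(processor, "patch_size", None) is None:
    processor.patch_size = config.vision_config.patch_size
```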

vllm/model_executor/models/minicpmv.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -342,6 +342,15 @@ def get_hf_processor(
         **kwargs: object,
     ):
         hf_processor = self.ctx.get_hf_processor()
+
+        # NumPy arrays are considered as Iterable but not Sequence in
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/image_transforms.py#L428
+        image_processor = hf_processor.image_processor  # type: ignore
+        for attr in ("mean", "std"):
+            val = getattr(image_processor, attr)
+            if isinstance(val, np.ndarray):
+                setattr(image_processor, attr, val.tolist())
+
         return hf_processor
 
     def get_image_processor(self):
```
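The distinction the comment points at is easy to verify: a NumPy array iterates fine but does not register as a `collections.abc.Sequence`, which is the shape the transformers type check expects, hence the `tolist()` conversion:

```python
from collections.abc import Iterable, Sequence

import numpy as np

mean = np.array([0.485, 0.456, 0.406])

assert isinstance(mean, Iterable)           # arrays do iterate...
assert not isinstance(mean, Sequence)       # ...but are not Sequences
assert isinstance(mean.tolist(), Sequence)  # a plain list is
```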

vllm/multimodal/inputs.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -141,9 +141,9 @@ class PlaceholderRange(TypedDict):
 def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
     """Equality check between :data:`NestedTensors` objects."""
     if isinstance(a, torch.Tensor):
-        return isinstance(b, torch.Tensor) and bool((a == b).all().item())
+        return isinstance(b, torch.Tensor) and torch.equal(a, b)
     elif isinstance(b, torch.Tensor):
-        return isinstance(a, torch.Tensor) and bool((b == a).all().item())
+        return isinstance(a, torch.Tensor) and torch.equal(b, a)
 
     if isinstance(a, list):
         return (isinstance(b, list)
```
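Beyond being shorter, `torch.equal` is stricter than an elementwise `==` reduction: the old form broadcasts, so tensors of different shapes could compare equal. A small demonstration:

```python
import torch

a = torch.zeros(2, 1)
b = torch.zeros(1, 2)

# Elementwise `==` broadcasts (2, 1) against (1, 2) into (2, 2),
# so the old check reports equality for differently-shaped tensors
assert bool((a == b).all().item())

# torch.equal requires identical shapes as well as identical values
assert not torch.equal(a, b)
```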
