[VLM][Bugfix] Pass processor kwargs properly on init (vllm-project#13516)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 authored Feb 19, 2025
1 parent 52ce14d commit 377d10b
Showing 44 changed files with 675 additions and 453 deletions.
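
In short: `mm_processor_kwargs` supplied when constructing the engine were previously not forwarded properly to the multimodal processor; the updated tests below exercise both the init-time and the per-request path and require identical outputs. A minimal sketch of the init-time usage, mirroring the h2ovl example updated in this diff (model name and kwarg values are illustrative, not prescriptive):

from vllm import LLM

# Processor kwargs supplied once, at engine construction; with this fix
# they reach the multimodal processor instead of being dropped, matching
# the behavior of passing them per request.
llm = LLM(
    model="h2oai/h2ovl-mississippi-800m",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": 2},
    mm_processor_kwargs={"max_dynamic_patch": 4},
)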
1 change: 1 addition & 0 deletions examples/offline_inference/vision_language_multi_image.py
@@ -85,6 +85,7 @@ def load_h2ovl(question: str, image_urls: List[str]) -> ModelRequestData:
trust_remote_code=True,
max_model_len=8192,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"max_dynamic_patch": 4},
)

placeholders = "\n".join(f"Image-{i}: <image>\n"
7 changes: 2 additions & 5 deletions tests/models/multimodal/processing/test_common.py
@@ -10,7 +10,7 @@
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -42,10 +42,7 @@ def _test_processing_correctness(
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_info.trust_remote_code,
),
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
225 changes: 131 additions & 94 deletions tests/models/multimodal/processing/test_h2ovl.py
@@ -1,17 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for H2OVL's multimodal preprocessing kwargs."""
from typing import Optional
from typing import Mapping, Optional

import pytest
from PIL import Image
from transformers import PretrainedConfig

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....conftest import _ImageAssets
from ...utils import build_model_context


def _get_expected_num_patches(
config: PretrainedConfig,
image: Image.Image,
num_imgs: int,
min_num: int,
max_num: int,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)

width, height = image.size

# Calculate the expected number of blocks
if num_imgs == 1 and config.use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num=1,
max_num=max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)

# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num=3,
max_num=max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus
# the one overlapping block
total_blocks = blocks1 + blocks2 - 1

return total_blocks

blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks

if config.use_thumbnail and expected_num_patches > 1:
expected_num_patches += 1

return expected_num_patches


def _run_check(
processor: BaseMultiModalProcessor,
images: list[Image.Image],
min_num: int,
max_num: int,
mm_processor_kwargs: Mapping[str, object],
):
tokenizer = processor.info.get_tokenizer()
config = processor.info.get_hf_config()

mm_data = {"image": images}

total_expected_num_patches = sum(
_get_expected_num_patches(config, image, len(images), min_num, max_num)
for image in images)

processed_inputs = processor.apply("<image>" * len(images), mm_data,
mm_processor_kwargs)

# Ensure we have the right number of image placeholders for the
# expected patch count
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape

assert img_tok_count == 256 * total_expected_num_patches
assert pixel_shape[0] == total_expected_num_patches


@pytest.mark.parametrize("model_id", [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
@@ -25,118 +126,54 @@
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
[4.0, 2.0, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
@pytest.mark.parametrize(
("min_dynamic_patch", "max_dynamic_patch"),
[(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)],
)
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
size_factors: list[float],
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
num_imgs: int,
kwargs_on_init: bool,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)
mm_processor_kwargs = {
"min_dynamic_patch": min_dynamic_patch,
"max_dynamic_patch": max_dynamic_patch,
"dynamic_image_size": dynamic_image_size,
}

ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": len(size_factors)},
)
tokenizer = cached_tokenizer_from_config(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

config = processor.info.get_hf_config()
use_msac = config.use_msac

mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size

min_num = config.min_dynamic_patch
min_num = min_dynamic_patch if dynamic_image_size else 1
max_num = max_dynamic_patch if dynamic_image_size else 1

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs

for asset in image_assets:
for factor in size_factors:
image = rescale_image_size(asset.pil_image, factor)
mm_data = {"image": [image] * num_imgs}

width, height = image.size

# Calculate the expected number of blocks
if num_imgs == 1 and use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)

# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)

# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0

# Total blocks is the sum of blocks from both passes minus
# overlapping
total_blocks = blocks1 + blocks2 - 1

expected_num_patches = total_blocks
else:
blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks

if config.use_thumbnail and expected_num_patches != 1:
expected_num_patches += 1

processed_inputs = processor.apply(prompt, mm_data,
mm_processor_kwargs)
pixel_shape = (
processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)

assert pixel_shape[0] == expected_num_patches * num_imgs
_run_check(
processor,
[
rescale_image_size(image_assets[0].pil_image, f)
for f in size_factors
],
min_num,
max_num,
hf_processor_mm_kwargs,
)
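
The new `kwargs_on_init` parameter above is the heart of this regression test: each configuration runs once with the kwargs stored in the model config at build time and once passed per `apply()` call, and both paths must produce the same placeholder and patch counts. A condensed sketch of that pattern, using names from the diff (the helper itself is hypothetical glue, and the `tests.models.utils` import path for `build_model_context` is assumed from the repo layout):

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from tests.models.utils import build_model_context  # import path assumed


def make_processor(model_id: str, kwargs: dict, kwargs_on_init: bool):
    ctx = build_model_context(
        model_name=model_id,
        tokenizer_name=model_id,
        trust_remote_code=True,
        # Init path: kwargs live in the model config from the start.
        mm_processor_kwargs=kwargs if kwargs_on_init else None,
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(
        ctx.model_config,
        tokenizer=cached_tokenizer_from_config(ctx.model_config),
    )
    # Request path: kwargs are handed to processor.apply(...) instead;
    # exactly one of the two paths carries them.
    return processor, ({} if kwargs_on_init else kwargs)

Either way, `processor.apply(prompt, mm_data, per_request_kwargs)` must yield the same outputs, which is what `_run_check` asserts.
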
24 changes: 16 additions & 8 deletions tests/models/multimodal/processing/test_idefics3.py
@@ -4,7 +4,7 @@
from transformers import Idefics3Config

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

from ....conftest import _ImageAssets
from ...utils import build_model_context
@@ -22,9 +22,15 @@
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(image_assets: _ImageAssets, model: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int, num_imgs: int):
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,
num_imgs: int,
kwargs_on_init: bool,
):
"""Ensure input_processor_for_idefics3 handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
@@ -33,15 +39,15 @@ def test_processor_override(image_assets: _ImageAssets, model: str,
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_tokenizer_from_config(ctx.model_config)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
hf_processor = processor.info.get_hf_processor(**mm_processor_kwargs)
hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs

# Build the image str / prompt based on the number of images we pass
placeholders = "<image>" if num_imgs == 1 else "\n".join(
@@ -54,8 +60,10 @@ def test_processor_override(image_assets: _ImageAssets, model: str,
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}

processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)

# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs[
"input_ids"][0]