
Commit c99a52f

alex-jw-brooks authored and DarkLight1337 committed
[Model] Expose Phi3v num_crops as a mm_processor_kwarg (vllm-project#8658)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
1 parent 4476c26 commit c99a52f
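
For orientation, here is a minimal usage sketch (not part of this commit) of the new override, mirroring the example changes below; the prompt text, the sampling settings, and the `image` variable are illustrative assumptions:

from vllm import LLM, SamplingParams

# Sketch only: num_crops is forwarded to the HF image processor via
# mm_processor_kwargs, as in examples/offline_inference_vision_language.py.
llm = LLM(
    model="microsoft/Phi-3.5-vision-instruct",
    trust_remote_code=True,
    max_model_len=4096,
    mm_processor_kwargs={"num_crops": 16},  # 16 for single-frame, 4 for multi-frame
)

# `image` is assumed to be a PIL.Image loaded by the caller.
prompt = "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)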

File tree

4 files changed, +230 −14 lines

examples/offline_inference_vision_language.py (+14)

@@ -83,10 +83,24 @@ def run_phi3v(question, modality):

     # In this example, we override max_num_seqs to 5 while
     # keeping the original context length of 128k.
+
+    # num_crops is an override kwarg to the multimodal image processor;
+    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
+    # to use 16 for single frame scenarios, and 4 for multi-frame.
+    #
+    # Generally speaking, a larger value for num_crops results in more
+    # tokens per image instance, because it may scale the image more in
+    # the image preprocessing. Some references in the model docs and the
+    # formula for image tokens after the preprocessing
+    # transform can be found below.
+    #
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
         model="microsoft/Phi-3-vision-128k-instruct",
         trust_remote_code=True,
         max_num_seqs=5,
+        mm_processor_kwargs={"num_crops": 16},
     )
     stop_token_ids = None
     return llm, prompt, stop_token_ids
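
The comment above points to the image-token formula in the HF processing code; as a rough sketch (derived from the tail of get_phi3v_image_feature_size in the vllm/model_executor/models/phi3v.py diff below, not an authoritative restatement), the token count grows with the HD-transformed image size, which in turn grows with num_crops:

# Approximation only: new_width/new_height are the post-HD-transform dimensions
# computed with hd_num=num_crops, so a larger num_crops yields more image tokens.
def approx_phi3v_image_tokens(new_width: int, new_height: int) -> int:
    return ((new_height // 336) * (new_width // 336) + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12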

examples/offline_inference_vision_language_multi_image.py (+13)

@@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData:


 def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData:
+    # num_crops is an override kwarg to the multimodal image processor;
+    # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
+    # to use 16 for single frame scenarios, and 4 for multi-frame.
+    #
+    # Generally speaking, a larger value for num_crops results in more
+    # tokens per image instance, because it may scale the image more in
+    # the image preprocessing. Some references in the model docs and the
+    # formula for image tokens after the preprocessing
+    # transform can be found below.
+    #
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
+    # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
     llm = LLM(
         model="microsoft/Phi-3.5-vision-instruct",
         trust_remote_code=True,
         max_model_len=4096,
         limit_mm_per_prompt={"image": len(image_urls)},
+        mm_processor_kwargs={"num_crops": 4},
     )
     placeholders = "\n".join(f"<|image_{i}|>"
                              for i, _ in enumerate(image_urls, start=1))

tests/models/decoder_only/vision_language/test_phi3v.py (+181 −5)

@@ -1,16 +1,21 @@
 import os
 import re
-from typing import List, Optional, Tuple, Type
+from typing import Callable, List, Optional, Tuple, Type

 import pytest
-from transformers import AutoTokenizer
+import torch
+from transformers import AutoImageProcessor, AutoTokenizer

+from vllm.inputs import InputContext, LLMInputs
+from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
+from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.utils import rescale_image_size
 from vllm.sequence import SampleLogprobs
 from vllm.utils import is_cpu, is_hip

-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
-from ...utils import check_logprobs_close
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
+                          _ImageAssets)
+from ...utils import build_model_context, check_logprobs_close

 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -71,7 +76,7 @@ def run_test(

     All the image fixtures for the test are from IMAGE_ASSETS.
     For huggingface runner, we provide the PIL images as input.
-    For vllm runner, we provide MultiModalDataDict objects 
+    For vllm runner, we provide MultiModalDataDict objects
     and corresponding MultiModalConfig as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
@@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+### Fast tests for correctness in processor_kwarg override handling
+
+
+# Wrap lazy imports to avoid initializing CUDA during test collection
+@pytest.fixture()
+def input_processor_for_phi3v():
+    from vllm.model_executor.models.phi3v import input_processor_for_phi3v
+    return input_processor_for_phi3v
+
+
+@pytest.fixture()
+def dummy_data_for_phi3v():
+    from vllm.model_executor.models.phi3v import dummy_data_for_phi3v
+    return dummy_data_for_phi3v
+
+
+@pytest.fixture()
+def get_max_phi3v_image_tokens():
+    from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens
+    return get_max_phi3v_image_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops", [4, 16, None])
+def test_input_mapper_override(model: str, image_assets: _ImageAssets,
+                               num_crops: Optional[int]):
+    """Ensure that the [default] input mapper handles num_crops properly."""
+    # We pass the processor kwargs here since for this model, we fall back to
+    # the default mapper; this will fall back to the HF mapper and forward
+    # mm_processor_kwargs to it.
+    mm_processor_kwargs = {
+        "num_crops": num_crops
+    } if num_crops is not None else {}
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=mm_processor_kwargs,
+    )
+
+    hf_processor = AutoImageProcessor.from_pretrained(model,
+                                                      trust_remote_code=True,
+                                                      **mm_processor_kwargs)
+
+    mm_registry = MultiModalRegistry()
+    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
+
+    image = image_assets[0].pil_image
+    hf_result = hf_processor.preprocess(
+        image,
+        return_tensors="pt",
+    )
+
+    vllm_result = mm_registry.map_input(
+        ctx.model_config,
+        {"image": image},
+    )
+
+    assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"])
+    assert torch.all(
+        hf_result["num_img_tokens"] == vllm_result["num_img_tokens"])
+
+    # For pixel values, the second axis should be the num_crops + 1
+    # for the rescaled original image. The default value in VLLM falls
+    # back to the HF config, which is why we compare to the processor num_crops
+    assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"])
+    assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_max_tokens", [
+    (4, 781),
+    (16, 2653),
+])
+def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str,
+                             num_crops: int, expected_max_tokens: int):
+    """Ensure get_max_phi3v_image_tokens handles num_crops properly."""
+    # NOTE: mm_processor_kwargs on the context in this test is unused, since
+    # this is testing the mapper directly. In practice, the processor kwargs
+    # are wrapped in a closure when calling the max tokens func. We explicitly
+    # do NOT use the mm_processor_kwargs in the model context here to ensure
+    # that the max image tokens implementation is referencing a mix of the
+    # kwargs to the function and the original mm_processor_kwargs in case
+    # values are somehow updated and end up in a bad state.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    actual_max_tokens = get_max_phi3v_image_tokens(
+        InputContext(ctx.model_config),
+        num_crops=num_crops,
+    )
+
+    assert expected_max_tokens == actual_max_tokens
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [
+    (4, 781, 1),
+    (4, 781, 2),
+    (16, 2653, 1),
+    (16, 2653, 2),
+])
+def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str,
+                             num_crops: int, toks_per_img: int, num_imgs: int):
+    """Ensure dummy_data_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the dummy data func.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+        mm_processor_kwargs=None,
+    )
+
+    sequence_data, _, = dummy_data_for_phi3v(
+        ctx=ctx,
+        seq_len=8192,  # Should be bigger than num_imgs * toks_per_img
+        mm_counts={"image": num_imgs},
+        num_crops=num_crops,
+    )
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == toks_per_img * num_imgs
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [
+    (4, 757, 1),
+    (4, 757, 2),
+    (16, 1921, 1),
+    (16, 1921, 2),
+])
+def test_input_processor_override(input_processor_for_phi3v: Callable,
+                                  image_assets: _ImageAssets, model: str,
+                                  num_crops: int, expected_toks_per_img: int,
+                                  num_imgs: int):
+    """Ensure input_processor_for_phi3v handles num_crops properly."""
+    # Same as the previous test - don't initialize mm_processor_kwargs
+    # in this test and assume that the kwargs will be correctly expanded by
+    # the partial when calling the custom input processor.
+    ctx = build_model_context(
+        model_name=model,
+        tokenizer_name=model,
+        trust_remote_code=True,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    # Build the image str / prompt based on the number of images we pass
+    img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
+    prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
+    images = [image_assets[0].pil_image] * num_imgs
+
+    llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt),
+                           prompt=prompt,
+                           multi_modal_data={"image": images})
+
+    proc_llm_inputs = input_processor_for_phi3v(
+        ctx=ctx,
+        llm_inputs=llm_inputs,
+        num_crops=num_crops,
+    )
+
+    # Ensure we have the right number of placeholders per num_crops size
+    img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID)
+    assert img_tok_count == expected_toks_per_img * num_imgs
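
The NOTE comments in these tests mention that, in normal operation, mm_processor_kwargs are bound to these callables ahead of time (wrapped in a closure or partial) rather than read back from the context. A hypothetical illustration of that binding, not the registry's actual code, might look like:

from functools import partial

# Hypothetical illustration: the engine-level mm_processor_kwargs are expanded
# into the per-model callables once, so later calls need no extra arguments.
mm_processor_kwargs = {"num_crops": 16}
max_tokens_fn = partial(get_max_phi3v_image_tokens, **mm_processor_kwargs)
dummy_data_fn = partial(dummy_data_for_phi3v, **mm_processor_kwargs)
input_processor_fn = partial(input_processor_for_phi3v, **mm_processor_kwargs)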

vllm/model_executor/models/phi3v.py (+22 −9)

@@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):


 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
+def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
     transposed = False
     if width < height:
         width, height = height, width
@@ -337,8 +337,10 @@ def get_phi3v_image_feature_size(
     *,
     input_height: int,
     input_width: int,
+    num_crops: int,
 ) -> int:
-    num_crops = hf_config.get("num_crops", 16)
+    if num_crops is None:
+        num_crops = hf_config.get("num_crops", 16)
     new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                     height=input_height,
                                                     hd_num=num_crops)
@@ -347,20 +349,26 @@ def get_phi3v_image_feature_size(
         + (new_height // 336 + 1) * 12


-def get_max_phi3v_image_tokens(ctx: InputContext):
+def get_max_phi3v_image_tokens(ctx: InputContext,
+                               *,
+                               num_crops: Optional[int] = None):

     return get_phi3v_image_feature_size(
         ctx.get_hf_image_processor_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        num_crops=num_crops,
     )


-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
-                         mm_counts: Mapping[str, int]):
+def dummy_data_for_phi3v(ctx: InputContext,
+                         seq_len: int,
+                         mm_counts: Mapping[str, int],
+                         *,
+                         num_crops: Optional[int] = None):
     num_images = mm_counts["image"]

-    image_feature_size = get_max_phi3v_image_tokens(ctx)
+    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

     seq_data = dummy_seq_data_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
@@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig,
     return image_placeholder_token_ids


-def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_phi3v(ctx: InputContext,
+                              llm_inputs: LLMInputs,
+                              *,
+                              num_crops: Optional[int] = None):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
@@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         image_feature_size = [
             get_phi3v_image_feature_size(hf_config,
                                          input_width=w,
-                                         input_height=h)
+                                         input_height=h,
+                                         num_crops=num_crops)
         ]
         image_data = [image_data]
     elif is_list_of(image_data, Image.Image):
@@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         image_feature_size.append(
             get_phi3v_image_feature_size(hf_config,
                                          input_width=w,
-                                         input_height=h))
+                                         input_height=h,
+                                         num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     elif is_list_of(image_data, torch.Tensor):
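
In short, the precedence after this change is: an explicit num_crops from mm_processor_kwargs wins, otherwise the HF image processor config, otherwise 16. A tiny standalone sketch of that resolution order (illustrative helper, not present in the diff):

from typing import Optional


def resolve_num_crops(hf_config: dict, num_crops: Optional[int] = None) -> int:
    # Explicit override (e.g. from mm_processor_kwargs) takes priority ...
    if num_crops is not None:
        return num_crops
    # ... otherwise fall back to the HF image processor config, then to 16.
    return hf_config.get("num_crops", 16)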
