[VLM] Merged multi-modal processor for Pixtral #12211

Merged
49 commits merged on Mar 15, 2025
Commits (49)
fbe6a9d
Adjustment first version
Flechman Jan 14, 2025
46c142f
Merge with main
Flechman Jan 22, 2025
4af1716
Revert changes
Flechman Jan 26, 2025
8a75f3a
Add pixtral dummy inputs builder
Flechman Jan 26, 2025
2e346d3
Fix naming
Flechman Jan 26, 2025
c9c082b
HF processor not supported
Flechman Jan 27, 2025
869a620
Add tokenizer mode
Flechman Jan 27, 2025
a6392cb
Override pixtral processor apply
Flechman Jan 27, 2025
c1b78f4
Merge with main
Flechman Feb 14, 2025
9d70fba
Merge with main
Flechman Mar 9, 2025
cafe731
Add caching mechanism
Flechman Mar 9, 2025
4c8f915
Add tokenization
Flechman Mar 9, 2025
c1bef45
Cleanup previous processor
Flechman Mar 9, 2025
d5fd5cd
Update based on latest PRs
DarkLight1337 Mar 10, 2025
9b0e436
Draft HF-compatible processor
DarkLight1337 Mar 10, 2025
a8d00e8
Add sanity check
DarkLight1337 Mar 10, 2025
6bed6d8
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 13, 2025
bfac110
Separate tests based on tokenizer type
DarkLight1337 Mar 13, 2025
2a30b0c
Cap number of sequences to fit in memory
DarkLight1337 Mar 13, 2025
fafe381
Update
DarkLight1337 Mar 13, 2025
128107e
Update
DarkLight1337 Mar 13, 2025
1fd4c54
Clean up
DarkLight1337 Mar 13, 2025
cc0edbe
Get the model to run
DarkLight1337 Mar 13, 2025
efd5d2c
Fix multi-image input
DarkLight1337 Mar 13, 2025
0022e35
More fixes
DarkLight1337 Mar 13, 2025
e332e17
Try to fix an edge case
DarkLight1337 Mar 13, 2025
b855a05
Fix
DarkLight1337 Mar 13, 2025
c9bb46d
Add a note
DarkLight1337 Mar 13, 2025
78aee1b
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 14, 2025
2bb3c4c
Auto-load chat template for Pixtral-HF
DarkLight1337 Mar 14, 2025
620a376
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 14, 2025
f40dc3a
Fix
DarkLight1337 Mar 14, 2025
2a8814b
Flatten
DarkLight1337 Mar 14, 2025
a68d328
Fix type annotation
DarkLight1337 Mar 14, 2025
9782fec
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 14, 2025
04e13b1
Clean
DarkLight1337 Mar 14, 2025
6a14d24
Fix type error
DarkLight1337 Mar 14, 2025
025a237
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 14, 2025
8566c27
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 14, 2025
083bc2f
Try fix V1
DarkLight1337 Mar 15, 2025
7a72365
Update
DarkLight1337 Mar 15, 2025
d216e16
Update
DarkLight1337 Mar 15, 2025
465541a
Rename and simplify
DarkLight1337 Mar 15, 2025
e653095
Optimize
DarkLight1337 Mar 15, 2025
3ec39ff
Fix
DarkLight1337 Mar 15, 2025
68eb2d6
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 15, 2025
1373b39
Fix V0 inference
DarkLight1337 Mar 15, 2025
b12f969
Merge branch 'main' into pixtral-mm-processor
DarkLight1337 Mar 15, 2025
8186f96
Fix
DarkLight1337 Mar 15, 2025
24 changes: 18 additions & 6 deletions examples/offline_inference/pixtral.py
@@ -43,12 +43,18 @@
# python demo.py advanced


def run_simple_demo():
def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192)

# Lower max_num_seqs or max_model_len on low-VRAM GPUs.
llm = LLM(model=model_name, tokenizer_mode="mistral")
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
max_model_len=4096,
max_num_seqs=2,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"
@@ -76,7 +82,7 @@ def run_simple_demo():
print(outputs[0].outputs[0].text)


def run_advanced_demo():
def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5
max_tokens_per_img = 4096
@@ -87,6 +93,7 @@ def run_advanced_demo():
tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg},
max_model_len=max_img_per_msg * max_tokens_per_img,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

prompt = "Describe the following image."
@@ -153,14 +160,19 @@ def main():
help="Specify the demo mode: 'simple' or 'advanced'",
)

parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')

args = parser.parse_args()

if args.mode == "simple":
print("Running simple demo...")
run_simple_demo()
run_simple_demo(args)
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo()
run_advanced_demo(args)


if __name__ == "__main__":
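For reference, the new `--disable-mm-preprocessor-cache` flag simply threads from the example's CLI into the `LLM` constructor. Below is a condensed sketch of the simple-demo path using the settings added in this diff; the script scaffolding is abbreviated, and the flag maps onto the engine argument of the same name.

```python
import argparse

from vllm import LLM, SamplingParams

# Minimal argument parsing: only the flag added in this PR is shown here.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-mm-preprocessor-cache",
    action="store_true",
    help="If True, disables caching of multi-modal preprocessor/mapper.",
)
args = parser.parse_args()

sampling_params = SamplingParams(max_tokens=8192)

# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    tokenizer_mode="mistral",
    max_model_len=4096,
    max_num_seqs=2,
    disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)
```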
200 changes: 149 additions & 51 deletions tests/models/multimodal/processing/test_common.py
@@ -2,17 +2,23 @@

import copy
from functools import partial
from typing import Optional
from typing import Optional, Union

import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import ProcessingCache
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
cached_tokenizer_from_config)

from ....multimodal.utils import random_audio, random_image, random_video
from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
}

tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, tokenizer will always add bos_token
# at the beginning of prompt by default, causing hf_processor outputs
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}

for batch_idx in range(num_batches):
mm_data = {
k:
@@ -115,43 +113,131 @@
elif len(mm_data[k]) == 1:
mm_data[k] = mm_data[k][0]

baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _drop_mm_kwargs_keys(
baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
baseline_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt, **tokenizer_encode_kwargs),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _drop_mm_kwargs_keys(
cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
cached_tokenized_result, ignore_mm_keys), (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
if isinstance(tokenizer, MistralTokenizer):
_test_processing_correctness_mistral(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)
else:
_test_processing_correctness_hf(
model_config,
tokenizer,
prompt,
mm_data,
baseline_processor,
cached_processor,
batch_idx,
ignore_mm_keys=ignore_mm_keys,
)


def _test_processing_correctness_hf(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
# For some multimodal models, the tokenizer always adds bos_token
# at the beginning of the prompt by default, causing hf_processor to
# output incorrect token ids. So we need to use `add_special_tokens=False`
# here to leave bos_token to be added by the processor.
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
else:
token_prompt = tokenizer.encode(prompt)

baseline_result = baseline_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_result = cached_processor.apply(
prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _inputs_equal(
baseline_result,
cached_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"

baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _inputs_equal(
baseline_result,
baseline_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"

cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _inputs_equal(
cached_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"


def _test_processing_correctness_mistral(
model_config: ModelConfig,
tokenizer: MistralTokenizer,
prompt: str,
mm_data: MultiModalDataDict,
baseline_processor: BaseMultiModalProcessor,
cached_processor: BaseMultiModalProcessor,
batch_idx: int,
ignore_mm_keys: Optional[list[str]] = None,
):
images = mm_data.get("image", [])
if not isinstance(images, list):
images = [images]

request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=prompt),
*(ImageChunk(image=image) for image in images),
]),
])
res = tokenizer.mistral.encode_chat_completion(request)
token_prompt = res.tokens

# Mistral chat outputs tokens directly, rather than text prompts
baseline_tokenized_result = baseline_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
token_prompt,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)

assert _inputs_equal(
baseline_tokenized_result,
cached_tokenized_result,
ignore_mm_keys,
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"


# yapf: disable
@@ -173,6 +259,7 @@ def _test_processing_correctness(
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistralai/Pixtral-12B-2409",
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
)


def _drop_mm_kwargs_keys(result: dict,
ignore_mm_keys: Optional[list[str]] = None) -> dict:
def _inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
):
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
b, ignore_mm_keys)


def _drop_mm_kwargs_keys(
result: MultiModalInputs,
ignore_mm_keys: Optional[list[str]] = None,
) -> MultiModalInputs:
"""Drop specified keys from result['mm_kwargs'].

This is mainly to avoid doing exact match of audio_features in ultravox.
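To summarize the test split above: HF-style tokenizers can encode the text prompt directly, so `_test_processing_correctness_hf` checks that text-prompt and token-prompt inputs agree across the baseline and cached processors, whereas Mistral tokenizers only produce token IDs through a chat-completion request, so `_test_processing_correctness_mistral` compares the two processors on that token prompt alone. A minimal sketch of the Mistral encoding step as a standalone helper (the helper name is hypothetical; the calls mirror the test code above):

```python
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

from vllm.transformers_utils.tokenizer import MistralTokenizer


def encode_mistral_prompt(tokenizer: MistralTokenizer, prompt: str,
                          images: list[Image.Image]) -> list[int]:
    """Hypothetical helper: encode text + images straight to token IDs.

    Mistral chat encoding emits tokens directly, so there is no text
    prompt to round-trip through the processor (unlike the HF path).
    """
    request = ChatCompletionRequest(messages=[
        UserMessage(content=[
            TextChunk(text=prompt),
            *(ImageChunk(image=image) for image in images),
        ]),
    ])
    return tokenizer.mistral.encode_chat_completion(request).tokens
```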