diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index 1692e13c4ec06..7a9c87f406c66 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -225,9 +225,9 @@ Multimodal Language Models
     - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
     -
   * - :code:`Phi3VForCausalLM`
-    - Phi-3-Vision
+    - Phi-3-Vision, Phi-3.5-Vision
     - Image
-    - :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
+    - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
     -
   * - :code:`MiniCPMV`
     - MiniCPM-V
diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py
index ccfc98a325982..197e63b1b1e52 100644
--- a/tests/models/test_phi3v.py
+++ b/tests/models/test_phi3v.py
@@ -21,7 +21,7 @@
     "<|user|>\n<|image_1|>\nWhat is the season?<|end|>\n<|assistant|>\n",
 })
 
-models = ["microsoft/Phi-3-vision-128k-instruct"]
+models = ["microsoft/Phi-3.5-vision-instruct"]
 
 
 def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
diff --git a/vllm/config.py b/vllm/config.py
index 7e62a727115ef..4cbdde5e113a2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -13,7 +13,9 @@
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import get_config, get_hf_text_config
+from vllm.transformers_utils.config import (get_config,
+                                            get_hf_image_processor_config,
+                                            get_hf_text_config)
 from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                         cuda_device_count_stateless, get_cpu_memory, is_cpu,
                         is_hip, is_neuron, is_openvino, is_xpu,
@@ -167,6 +169,8 @@ def __init__(
         self.hf_config = get_config(self.model, trust_remote_code, revision,
                                     code_revision, rope_scaling, rope_theta)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.hf_image_processor_config = get_hf_image_processor_config(
+            self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Choose a default enforce_eager value if the user did not specify
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index deb66f0b0cb35..ae6c6c05d9f72 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -2,8 +2,8 @@
 from array import array
 from collections import UserDict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Callable, Dict, Mapping, Optional, Protocol,
-                    Tuple, Type)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional,
+                    Protocol, Tuple, Type)
 
 from torch import nn
 from transformers import PretrainedConfig
@@ -55,6 +55,13 @@ def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C:
 
         return hf_config
 
+    def get_hf_image_processor_config(self) -> Dict[str, Any]:
+        """
+        Get the HuggingFace image processor configuration of the model.
+        """
+
+        return self.model_config.hf_image_processor_config
+
 
 N = TypeVar("N", bound=Type[nn.Module])
 
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 9ccd6ef6d9ace..4854377215608 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 import re
 from functools import lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Tuple, TypedDict, Union)
 
 import numpy as np
 import torch
@@ -324,12 +324,12 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
 def get_phi3v_image_feature_size(
-    hf_config: PretrainedConfig,
+    hf_config: Dict[str, Any],
     *,
     input_height: int,
     input_width: int,
 ) -> int:
-    num_crops = getattr(hf_config, "num_crops", 16)
+    num_crops = hf_config.get("num_crops", 16)
     new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                     height=input_height,
                                                     hd_num=num_crops)
@@ -341,7 +341,7 @@ def get_phi3v_image_feature_size(
 def get_max_phi3v_image_tokens(ctx: InputContext):
 
     return get_phi3v_image_feature_size(
-        ctx.get_hf_config(),
+        ctx.get_hf_image_processor_config(),
         input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
         input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
     )
@@ -395,7 +395,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
         return llm_inputs
 
     model_config = ctx.model_config
-    hf_config = ctx.get_hf_config()
+    hf_config = ctx.get_hf_image_processor_config()
 
     image_data = multi_modal_data["image"]
     if isinstance(image_data, Image.Image):
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index d3024965c0b4c..0f86b02deb21a 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,8 +1,10 @@
 import contextlib
 from pathlib import Path
-from typing import Dict, Optional, Type, Union
+from typing import Any, Dict, Optional, Type, Union
 
 from transformers import GenerationConfig, PretrainedConfig
+from transformers.models.auto.image_processing_auto import (
+    get_image_processor_config)
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
 
@@ -98,6 +100,17 @@ def get_config(
     return config
 
 
+def get_hf_image_processor_config(
+    model: Union[str, Path],
+    revision: Optional[str] = None,
+    **kwargs,
+) -> Dict[str, Any]:
+    # Separate model folder from file path for GGUF models
+    if Path(model).is_file() and Path(model).suffix == ".gguf":
+        model = Path(model).parent
+    return get_image_processor_config(model, revision=revision, **kwargs)
+
+
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
     No op for pure text models.
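For context, a minimal sketch of the lookup path this patch introduces, for anyone reading it outside the vLLM tree: the image-processor settings for Phi-3.5-Vision (such as num_crops) live in the model's preprocessor_config.json rather than on the text config, which is why getattr(hf_config, "num_crops", 16) becomes a dict lookup. The snippet is illustrative only and calls transformers directly instead of going through vLLM; the model id is taken from the docs table above, and running it assumes network access to the HuggingFace Hub.

# Illustrative sketch only -- mirrors the code path added in this patch,
# not a vLLM public API.
from typing import Any, Dict

from transformers.models.auto.image_processing_auto import (
    get_image_processor_config)

# Model id taken from the updated docs table; fetching it requires Hub access.
model_id = "microsoft/Phi-3.5-vision-instruct"

# Returns the contents of preprocessor_config.json as a plain dict, which is
# what ModelConfig.hf_image_processor_config now stores.
image_processor_config: Dict[str, Any] = get_image_processor_config(model_id)

# Same default as the patched get_phi3v_image_feature_size():
# a dict .get() instead of getattr() on the text config.
num_crops = image_processor_config.get("num_crops", 16)
print(num_crops)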