Skip to content

Commit

Permalink
[Bugfix] Add phi3v resize for dynamic shape and fix torchvision requi…
Browse files Browse the repository at this point in the history
…rement (vllm-project#5772)

Signed-off-by: Alvant <alvasian@yandex.ru>
  • Loading branch information
Isotr0py authored and Alvant committed Oct 26, 2024
1 parent 65a4cfd commit 90da4f2
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements-cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@

# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu
torchvision == 0.18.1+cpu # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
2 changes: 2 additions & 0 deletions requirements-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,7 @@
ray >= 2.9
nvidia-ml-py # for pynvml package
torch == 2.3.0
# These must be updated alongside torch
torchvision == 0.18.0 # Required for phi3v processor, also see https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
1 change: 0 additions & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ peft
requests
ray
sentence-transformers # required for embedding
torchvision # required for the image processor of phi3v

# Benchmarking
aiohttp
Expand Down
4 changes: 4 additions & 0 deletions tests/models/test_phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
def iter_phi3v_configs(model_name: str):
image_hw_to_feature_size = {
(1008, 1344): 1921,
(2016, 2688): 1933,
}

for (h, w), f in image_hw_to_feature_size.items():
Expand Down Expand Up @@ -75,6 +76,9 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, here is
# numeric difference for longer context and test can't pass
@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
Expand Down
69 changes: 65 additions & 4 deletions vllm/model_executor/models/phi3v.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
Expand All @@ -32,9 +35,11 @@
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens",
}
Expand Down Expand Up @@ -268,7 +273,63 @@ class Phi3VImagePixelInputs(TypedDict):
"""Shape: (batch_size, 2)"""


@MULTIMODAL_REGISTRY.register_image_pixel_input()
# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_padded_size(width, height, padding_unit=336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
bottom_padding = target_height - height - top_padding
padded_width = width
padded_height = height + top_padding + bottom_padding
return padded_width, padded_height


# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_hd_transform_size(width, height, hd_num=16):
transposed = False
if width < height:
width, height = height, width
transposed = True

ratio = width / height
scale = 1
while scale * np.ceil(scale / ratio) <= hd_num:
scale += 1
scale -= 1

new_width = int(scale * 336)
new_height = int(new_width / ratio)

padded_width, padded_height = calc_padded_size(new_width, new_height)

if transposed:
padded_width, padded_height = padded_height, padded_width

return padded_width, padded_height


def _image_processor(
data: ImagePixelData,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
image = data.image

if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = vlm_config.image_input_shape
if (w, h) != calc_hd_transform_size(image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)

data.image = image.resize((w, h))

return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_processor(data, model_config, vlm_config)


@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class Phi3VForCausalLM(VisionLanguageModelBase):

Expand Down

0 comments on commit 90da4f2

Please sign in to comment.