From 2dfd320767fc51838abe4ef94313ba6e95cf99b6 Mon Sep 17 00:00:00 2001 From: Shanshan Wang Date: Mon, 28 Oct 2024 06:17:51 +0000 Subject: [PATCH] format Signed-off-by: Shanshan Wang --- examples/offline_inference_vision_language.py | 2 +- ...e_inference_vision_language_multi_image.py | 3 +- vllm/entrypoints/chat_utils.py | 3 +- vllm/model_executor/models/h2ovl.py | 249 +++++++++++------- vllm/model_executor/models/registry.py | 2 +- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 3 +- vllm/transformers_utils/configs/h2ovl.py | 7 +- 8 files changed, 172 insertions(+), 100 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 43b806947947c..43bf187d4c60e 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -502,4 +502,4 @@ def main(args): default=16, help='Number of frames to extract from the video.') args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 15695af846935..d99684078ff3d 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -106,6 +106,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: chat_template=None, ) + def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "h2oai/h2ovl-mississippi-2b" @@ -139,6 +140,7 @@ def load_h2onvl(question: str, image_urls: List[str]) -> ModelRequestData: chat_template=None, ) + def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "OpenGVLab/InternVL2-2B" @@ -318,7 +320,6 @@ def run_generate(model, question: str, image_urls: List[str]): for o in outputs: generated_text = o.outputs[0].text print(generated_text) - def run_chat(model: str, question: str, image_urls: List[str]): diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 221a0e49987bf..256339106eaa3 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -183,7 +183,8 @@ def _placeholder_str(self, modality: ModalityStr, if model_type.startswith("llava"): return self._cached_token_str(self._tokenizer, hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat", "NVLM_D", "h2ovl_chat"): + if model_type in ("chameleon", "internvl_chat", "NVLM_D", + "h2ovl_chat"): return "" if model_type == "mllama": return "<|image|>" diff --git a/vllm/model_executor/models/h2ovl.py b/vllm/model_executor/models/h2ovl.py index 064a3f0df58f8..233ede4a48be3 100644 --- a/vllm/model_executor/models/h2ovl.py +++ b/vllm/model_executor/models/h2ovl.py @@ -5,12 +5,11 @@ # Copyright (c) 2024 H2O.AI # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- -from typing import Optional, Tuple, List, Mapping from functools import partial -from PIL import Image +from typing import List, Optional, Tuple import torch -import torch.nn as nn +from PIL import Image from transformers import PretrainedConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, @@ -22,20 +21,21 @@ from vllm.utils import is_list_of from .intern_vit import InternVisionModel -from .internvl import (InternVLChatModel, - InternVLInputPipeline, - build_transform, - find_closest_aspect_ratio, - 
get_internvl_num_patches, - get_max_internvl_image_size, - IMG_START, IMG_END, IMG_CONTEXT) +from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel, + InternVLInputPipeline, build_transform, + find_closest_aspect_ratio, get_internvl_num_patches) # modified to include blocks generated in second pass -def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, - max_num: int, image_size: int, - use_thumbnail: bool, - prior_aspect_ratio=None) -> Tuple[int, int, int, Tuple[int, int]]: +def calculate_num_blocks( + orig_width: int, + orig_height: int, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio=None, +) -> Tuple[int, int, int, Tuple[int, int]]: aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio @@ -46,8 +46,10 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, # if prior_aspect_ratio is provided, filter the target ratios if prior_aspect_ratio is not None: - target_ratios = [ratio for ratio in target_ratios if - prior_aspect_ratio[0] % ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0] + target_ratios = [ + ratio for ratio in target_ratios if prior_aspect_ratio[0] % + ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0 + ] # find the closest aspect ratio to the target target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, @@ -65,27 +67,35 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B -def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, - image_size: int, - use_thumbnail: bool) -> Tuple[List[Image.Image], Tuple[int, int]]: +def dynamic_preprocess( + image: Image.Image, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, +) -> Tuple[List[Image.Image], Tuple[int, int]]: orig_width, orig_height = image.size # calculate the number of blocks without thumbnail - blocks, target_width, target_height, target_aspect_ratio = calculate_num_blocks( - orig_width, - orig_height, - min_num, - max_num, - image_size, - use_thumbnail=False) + blocks, target_width, target_height, target_aspect_ratio = ( + calculate_num_blocks( + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False, + )) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -97,8 +107,14 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, # new dynamic_preprocess2 with prior_aspect_ratio -def dynamic_preprocess2(image: Image.Image, min_num: int, max_num: int, - image_size: int, use_thumbnail: bool, prior_aspect_ratio: Tuple[int, int]) -> List[Image.Image]: +def dynamic_preprocess2( + image: Image.Image, + min_num: int, + max_num: int, + image_size: int, + use_thumbnail: bool, + prior_aspect_ratio: Tuple[int, int], +) -> List[Image.Image]: orig_width, orig_height = image.size # calculate the number of blocks 
based on prior aspect ratio @@ -109,15 +125,18 @@ def dynamic_preprocess2(image: Image.Image, min_num: int, max_num: int, max_num, image_size, use_thumbnail=False, - prior_aspect_ratio=prior_aspect_ratio) + prior_aspect_ratio=prior_aspect_ratio, + ) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] for i in range(blocks): - box = ((i % (target_width // image_size)) * image_size, - (i // (target_width // image_size)) * image_size, - ((i % (target_width // image_size)) + 1) * image_size, - ((i // (target_width // image_size)) + 1) * image_size) + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) # split the image split_img = resized_img.crop(box) processed_images.append(split_img) @@ -128,45 +147,82 @@ def dynamic_preprocess2(image: Image.Image, min_num: int, max_num: int, return processed_images -def load_image1(image:Image.Image, input_size=448, min_num=1, max_num=6): +def load_image1(image: Image.Image, input_size=448, min_num=1, max_num=6): # image = Image.open(image_file).convert('RGB') transform = build_transform(input_size=input_size) - images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num) + images, target_aspect_ratio = dynamic_preprocess( + image, + image_size=input_size, + use_thumbnail=True, + min_num=min_num, + max_num=max_num, + ) pixel_values = [transform(image) for image in images] pixel_values = torch.stack(pixel_values) return pixel_values, target_aspect_ratio -def load_image2(image:Image.Image, input_size=448, min_num=1, max_num=6, target_aspect_ratio=None): + +def load_image2( + image: Image.Image, + input_size=448, + min_num=1, + max_num=6, + target_aspect_ratio=None, +): # image = Image.open(image_file).convert('RGB') transform = build_transform(input_size=input_size) - images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio) + images = dynamic_preprocess2( + image, + image_size=input_size, + use_thumbnail=True, + min_num=min_num, + max_num=max_num, + prior_aspect_ratio=target_aspect_ratio, + ) pixel_values = [transform(image) for image in images] pixel_values = torch.stack(pixel_values) return pixel_values -def image_to_pixel_values(image:Image.Image, - input_size: int, min_num: int, - max_num: int, use_thumbnail: bool, - use_MSAC: bool) -> torch.Tensor: +def image_to_pixel_values( + image: Image.Image, + input_size: int, + min_num: int, + max_num: int, + use_thumbnail: bool, + use_MSAC: bool, +) -> torch.Tensor: # when MSAC is turned on, we need to process the image twice if use_MSAC: - pixel_values, target_aspect_ratio = load_image1(image, input_size=input_size, min_num=min_num, max_num=max_num) - pixel_values2 = load_image2(image, input_size=input_size, min_num=min_num, max_num=max_num, target_aspect_ratio=target_aspect_ratio) - pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0) + pixel_values, target_aspect_ratio = load_image1(image, + input_size=input_size, + min_num=min_num, + max_num=max_num) + pixel_values2 = load_image2( + image, + input_size=input_size, + min_num=min_num, + max_num=max_num, + target_aspect_ratio=target_aspect_ratio, + ) + pixel_values = torch.cat( + [pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 
0) else: transform = build_transform(input_size=input_size) - images, _ = dynamic_preprocess(image, - min_num=min_num, - max_num=max_num, - image_size=input_size, - use_thumbnail=use_thumbnail) + images, _ = dynamic_preprocess( + image, + min_num=min_num, + max_num=max_num, + image_size=input_size, + use_thumbnail=use_thumbnail, + ) pixel_values = [transform(image) for image in images] - pixel_values = torch.stack(pixel_values) + pixel_values = torch.stack(pixel_values) return pixel_values + def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): image_size = hf_config.vision_config.image_size @@ -175,12 +231,14 @@ def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail use_MSAC = hf_config.use_msac - return partial(image_to_pixel_values, - input_size=image_size, - min_num=min_num, - max_num=max_dynamic_patch, - use_thumbnail=use_thumbnail, - use_MSAC=use_MSAC) + return partial( + image_to_pixel_values, + input_size=image_size, + min_num=min_num, + max_num=max_dynamic_patch, + use_thumbnail=use_thumbnail, + use_MSAC=use_MSAC, + ) def get_max_internvl_image_tokens(ctx: InputContext, @@ -192,12 +250,12 @@ def get_max_internvl_image_tokens(ctx: InputContext, hf_config = ctx.get_hf_config() use_thumbnail = hf_config.use_thumbnail use_MSAC = hf_config.use_msac - + if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch num_patches = get_internvl_num_patches(hf_config) - + coefficient = 2 if use_MSAC else 1 num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0) @@ -205,32 +263,35 @@ def get_max_internvl_image_tokens(ctx: InputContext, class H2OVLInputPipeline(InternVLInputPipeline): + def __init__(self): super().__init__(IMG_START, IMG_END, IMG_CONTEXT) + def input_processor( - self, + self, ctx: InputContext, inputs: DecoderOnlyInputs, - *, + *, max_dynamic_patch: Optional[int] = None, ) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return inputs - + model_config = ctx.model_config hf_config = ctx.get_hf_config() - + image_data = multi_modal_data["image"] num_patches = get_internvl_num_patches(hf_config) - + # can only get the total blocks num after the image fully processed - num_blocks_calculator = image_to_pixel_values_wrapper(hf_config, max_dynamic_patch=max_dynamic_patch) - + num_blocks_calculator = image_to_pixel_values_wrapper( + hf_config, max_dynamic_patch=max_dynamic_patch) + if isinstance(image_data, Image.Image): num_blocks = num_blocks_calculator(image_data).shape[0] image_feature_sizes = [num_blocks * num_patches] - + elif is_list_of(image_data, Image.Image): # Do not use MSAC for multi images hf_config.use_msac = False @@ -238,64 +299,70 @@ def input_processor( for image in image_data: num_blocks = num_blocks_calculator(image).shape[0] image_feature_sizes.append(num_blocks * num_patches) - + elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape image_feature_sizes = [image_feature_size] else: raise TypeError(f"Invalid image type: {type(image_data)}") - + tokenizer = cached_get_tokenizer( model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - + trust_remote_code=model_config.trust_remote_code, + ) + prompt = inputs.get("prompt") prompt_token_ids = inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) 
- + new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) - return token_inputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) - + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + ) + def input_mapper( - self, + self, ctx: InputContext, data: object, *, max_dynamic_patch: Optional[int] = None, ): hf_config = ctx.get_hf_config() - + image_pixel_values_mapper = image_to_pixel_values_wrapper( hf_config, max_dynamic_patch) - + if isinstance(data, Image.Image): data = image_pixel_values_mapper(data) data = data.unsqueeze(0) elif is_list_of(data, Image.Image): - hf_config.use_msac = False + hf_config.use_msac = False data = [image_pixel_values_mapper(img) for img in data] - + else: return MultiModalInputs({"image_embeds": data}) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, - trust_remote_code=model_config.trust_remote_code) - image_token_id = tokenizer.encode(self.img_context_token, - add_special_tokens=False, - return_tensors="pt")[0] + trust_remote_code=model_config.trust_remote_code, + ) + image_token_id = tokenizer.encode( + self.img_context_token, + add_special_tokens=False, + return_tensors="pt", + )[0] return MultiModalInputs({ "pixel_values": data, "image_token_id": image_token_id }) - + input_pipeline = H2OVLInputPipeline() @@ -317,8 +384,8 @@ def _init_vision_model( if not is_mono: vision_feature_layer = config.select_layer if vision_feature_layer < 0: - num_hidden_layers = config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 + num_hidden_layers = (config.vision_config.num_hidden_layers + + vision_feature_layer + 1) else: num_hidden_layers = vision_feature_layer + 1 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fc25ed8ea82a7..cfb35838ca0c0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -461,4 +461,4 @@ def _run() -> None: if __name__ == "__main__": - _run() + _run() \ No newline at end of file diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 33ca5d0b2f639..08697274854e0 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -18,7 +18,8 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - EAGLEConfig, ExaoneConfig, H2OVLChatConfig, + EAGLEConfig, ExaoneConfig, + H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index aedfe925a07cd..d1e19c9a33c24 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -23,6 +23,7 @@ "DbrxConfig", "MPTConfig", "RWConfig", + "H2OVLChatConfig", "InternVLChatConfig", "JAISConfig", "MedusaConfig", @@ -34,4 +35,4 @@ "NVLM_D_Config", "SolarConfig", "UltravoxConfig", -] +] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/h2ovl.py b/vllm/transformers_utils/configs/h2ovl.py index 7d4941c0120b8..b94c5b77e4b7f 100644 --- a/vllm/transformers_utils/configs/h2ovl.py +++ b/vllm/transformers_utils/configs/h2ovl.py @@ -1,12 +1,13 @@ # Adapted from # 
https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py # -------------------------------------------------------- -# H2OVL -# Copyright (c) 2024 H2O.ai +# H2OVL-Mississippi +# Copyright (c) 2024 H2O.AI # Licensed under Apache 2.0 License [see LICENSE for details] # -------------------------------------------------------- from .internvl import InternVLChatConfig + class H2OVLChatConfig(InternVLChatConfig): - model_type = "h2ovl_chat" \ No newline at end of file + model_type = "h2ovl_chat"
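
Usage sketch (not part of the diff above): the snippet below shows how the h2ovl_chat support added by this patch could be exercised through vLLM's offline LLM API. The model id "h2oai/h2ovl-mississippi-2b", the "<image>" placeholder, and the max_dynamic_patch knob are taken from the patch itself; the raw "<|prompt|>...<|end|><|answer|>" template, the max_model_len value, and the mm_processor_kwargs plumbing are assumptions, so in practice the model tokenizer's chat template should be preferred over a hard-coded prompt string.

    # Hedged sketch only: not the canonical example shipped with vLLM.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="h2oai/h2ovl-mississippi-2b",
        trust_remote_code=True,
        max_model_len=8192,  # assumed; must cover the expanded image tokens
        # The input processor/mapper in h2ovl.py accept max_dynamic_patch,
        # which caps how many 448x448 tiles each image is split into.
        mm_processor_kwargs={"max_dynamic_patch": 4},
    )

    image = Image.open("example.jpg").convert("RGB")  # any local RGB image

    # Assumed raw template; h2ovl_chat resolves the image placeholder to
    # "<image>" (see the chat_utils.py change above).
    prompt = "<|prompt|><image>\nDescribe this image.<|end|><|answer|>"

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)

Note on sizing: with MSAC enabled, get_max_internvl_image_tokens budgets coefficient * max_dynamic_patch + 1 blocks per image (coefficient = 2 for the two preprocessing passes, plus one thumbnail block), and each block contributes get_internvl_num_patches(hf_config) context tokens, so max_dynamic_patch is the main lever on per-image prompt length.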