From cb197433a043f3bab32f1bb8eb9c6ba870d06d0f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 19 Apr 2024 07:49:25 +0000 Subject: [PATCH] Fix LLaVA example and test w.r.t. image processing refactor - Note that we now load the images directly instead of from `.pt` files --- examples/llava_example.py | 6 +-- tests/conftest.py | 22 ++++----- tests/models/test_llava.py | 93 ++++++++++++++++++++++++-------------- 3 files changed, 70 insertions(+), 51 deletions(-) diff --git a/examples/llava_example.py b/examples/llava_example.py index 1c25d3fe9491b..4da3653235230 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -18,16 +18,13 @@ def run_llava_pixel_values(): image_token_id=32000, image_input_shape="1,3,336,336", image_feature_size=576, - no_image_processor=True, ) prompt = "" * 576 + ( "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - image_tensor: torch.Tensor = torch.load("images/stop_sign_pixel_values.pt") - image_arr = image_tensor.view(3, 336, 336).permute((1, 2, 0)).numpy() - image = Image.fromarray(image_arr, mode="RGB") + image = Image.open("images/stop_sign.jpg") outputs = llm.generate(prompt, multi_modal_datas=ImagePixelData(image)) for o in outputs: @@ -42,7 +39,6 @@ def run_llava_image_features(): image_token_id=32000, image_input_shape="1,576,1024", image_feature_size=576, - no_image_processor=True, ) prompt = "" * 576 + ( diff --git a/tests/conftest.py b/tests/conftest.py index acd1caa30b75c..628df9aaf6c38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel -from vllm.sequence import MultiModalData +from vllm.sequence import ImageFeatureData, ImagePixelData, MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer _TEST_DIR = os.path.dirname(__file__) @@ -21,10 +21,6 @@ _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] # Multi modal related -_PIXEL_VALUES_FILES = [ - os.path.join(_TEST_DIR, "images", filename) for filename in - ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"] -] _IMAGE_FEATURES_FILES = [ os.path.join(_TEST_DIR, "images", filename) for filename in ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"] @@ -37,8 +33,7 @@ "\nUSER: What's the content of the image?\nASSISTANT:", "\nUSER: What is the season?\nASSISTANT:" ] -assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len( - _IMAGE_FILES) == len(_IMAGE_PROMPTS) +assert len(_IMAGE_FEATURES_FILES) == len(_IMAGE_FILES) == len(_IMAGE_PROMPTS) def _read_prompts(filename: str) -> List[str]: @@ -86,15 +81,18 @@ def hf_images() -> List[Image.Image]: @pytest.fixture() -def vllm_images(request) -> List[torch.Tensor]: +def vllm_images(request) -> List[MultiModalData]: vision_language_config = request.getfixturevalue("model_and_config")[1] if vision_language_config.image_input_type == ( VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): - filenames = _IMAGE_FEATURES_FILES + return [ + ImageFeatureData(torch.load(filename)) + for filename in _IMAGE_FEATURES_FILES + ] else: - filenames = _PIXEL_VALUES_FILES - - return [torch.load(filename) for filename in filenames] + return [ + ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES + ] @pytest.fixture() diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 18d608af976b9..2e2a6faa18b59 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -8,36 +8,47 @@ from transformers import AutoTokenizer from vllm.config import VisionLanguageConfig -from vllm.sequence import ImagePixelData def iter_llava_configs(model_name: str): - for input_type, input_shape in [ - (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, 336, 336)), - (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, 576, 1024)), - ]: - yield (model_name, - VisionLanguageConfig(image_input_type=input_type, - image_feature_size=576, - image_token_id=32000, - image_input_shape=input_shape, - image_processor=None, - image_processor_revision=None)) + image_hw_to_feature_size = { + (336, 336): 576, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) def iter_llava_next_configs(model_name: str): - for input_type, input_shape in [ - # `vision_config` on HuggingFace only supports `image_size=336` - (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, 336, 336)), - (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, 576, 1024)), - ]: - yield (model_name, - VisionLanguageConfig(image_input_type=input_type, - image_feature_size=576, - image_token_id=64000, - image_input_shape=input_shape, - image_processor=None, - image_processor_revision=None)) + image_hw_to_feature_size = { + (336, 336): 1176, + (672, 672): 2928, + (1344, 336): 1944, + (336, 1344): 1890, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=64000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) model_and_vl_config = [ @@ -99,17 +110,27 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, """Inference result should be the same between hf and vllm. All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the raw images as input. - For vllm runner, we provide image tensors and corresponding + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalData objects and corresponding vision language config as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ model_id, vision_language_config = model_and_config + hf_model = hf_runner(model_id, dtype=dtype) - hf_outputs = hf_model.generate_greedy(hf_image_prompts, - max_tokens, - images=hf_images) + _, vision_language_config = model_and_config + if vision_language_config.image_input_type == ( + VisionLanguageConfig.ImageInputType.IMAGE_FEATURES): + # HuggingFace does not support image feature input + hf_outputs = [None] * len(hf_image_prompts) + else: + _, _, h, w = vision_language_config.image_input_shape + hf_outputs = hf_model.generate_greedy( + hf_image_prompts, + max_tokens, + # To be compatible with the patch for LLaVA-NeXT + images=[im.resize((w, h)) for im in hf_images]) del hf_model vllm_model = vllm_runner(model_id, @@ -117,19 +138,23 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images, worker_use_ray=worker_use_ray, enforce_eager=True, **as_dict(vision_language_config)) - vllm_outputs = vllm_model.generate_greedy( - vllm_image_prompts, - max_tokens, - multi_modal_datas=[ImagePixelData(image) for image in vllm_images]) + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + multi_modal_datas=vllm_images) del vllm_model gc.collect() torch.cuda.empty_cache() for i in range(len(hf_image_prompts)): - hf_output_ids, hf_output_str = hf_outputs[i] + hf_output = hf_outputs[i] + if hf_output is None: + continue + + hf_output_ids, hf_output_str = hf_output vllm_output_ids, vllm_output_str = sanitize_vllm_output( vllm_outputs[i], vision_language_config, model_id) + print(f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_str == vllm_output_str, ( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, (