Skip to content

Commit

Permalink
[Model] Initial support for LLaVA-NeXT (vllm-project#4199)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
2 people authored and jimpang committed Jul 24, 2024
1 parent 9c4eed0 commit 11e9467
Show file tree
Hide file tree
Showing 7 changed files with 640 additions and 18 deletions.
6 changes: 5 additions & 1 deletion docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it.
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
-
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
Expand Down
2 changes: 0 additions & 2 deletions tests/models/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str):

model_and_vl_config = [
*iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
# Not enough memory
# *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
]


Expand Down
123 changes: 123 additions & 0 deletions tests/models/test_llava_next.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

_PREFACE = (
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions.")

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_llava_next_configs(model_name: str):
image_hw_to_feature_size = {
(336, 336): 1176,
(672, 672): 2928,
(1344, 336): 1944,
(336, 1344): 1890,
}

for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))


model_and_vl_config = [
*iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
vlm_config: VisionLanguageConfig, model_id: str):
"""Sanitize vllm output to be comparable with hf output.
The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
It also reduces `output_str` from "<image><image>bla" to "bla".
"""
input_ids, output_str = vllm_output
image_token_id = vlm_config.image_token_id

tokenizer = AutoTokenizer.from_pretrained(model_id)
image_token_str = tokenizer.decode(image_token_id)

hf_input_ids = [
input_id for idx, input_id in enumerate(input_ids)
if input_id != image_token_id or input_ids[idx - 1] != image_token_id
]
hf_output_str = output_str \
.replace(image_token_str * vlm_config.image_feature_size, " ")

return hf_input_ids, hf_output_str


@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
model_and_config, dtype: str, max_tokens: int) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
max_tokens,
images=hf_images)

vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
for p in HF_IMAGE_PROMPTS
]

with vllm_runner(
model_id,
dtype=dtype,
# should be greater than image_feature_size
max_model_len=4096,
enforce_eager=True,
**vlm_config.as_cli_args_dict(),
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
max_tokens,
images=vllm_images)

for i in range(len(HF_IMAGE_PROMPTS)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_to_hf_output(
vllm_outputs[i], vlm_config, model_id)
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
62 changes: 55 additions & 7 deletions tests/multimodal/test_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from transformers import CLIPImageProcessor
from transformers import CLIPImageProcessor, LlavaNextImageProcessor

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
Expand All @@ -12,7 +12,7 @@
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_clip_image_processor(hf_images, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 33
IMAGE_HEIGHT = IMAGE_WIDTH = 560

hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, CLIPImageProcessor)
Expand Down Expand Up @@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype):
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.xfail(
reason="Inconsistent image processor being used due to lack "
"of support for dynamic image token replacement")
@pytest.mark.parametrize("dtype", ["half", "float"])
def test_llava_next_image_processor(hf_images, dtype):
MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 560

hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
assert isinstance(hf_processor, LlavaNextImageProcessor)

model_config = ModelConfig(
model=MODEL_NAME,
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
)

for image in hf_images:
hf_result = hf_processor.preprocess(
image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
ImagePixelData(image),
model_config=model_config,
vlm_config=vlm_config,
)

assert hf_result.keys() == vllm_result.keys()
for key, hf_tensor in hf_result.items():
hf_arr: np.ndarray = hf_tensor.numpy()
vllm_arr: np.ndarray = vllm_result[key].numpy()

assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


@pytest.mark.xfail(
reason="Example image pixels were not processed using HuggingFace")
@pytest.mark.parametrize("dtype", ["float"])
def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
IMAGE_HEIGHT = IMAGE_WIDTH = 33
IMAGE_HEIGHT = IMAGE_WIDTH = 560

model_config = ModelConfig(
model=MODEL_NAME,
Expand Down Expand Up @@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
tensor_arr: np.ndarray = tensor_result[key].numpy()

assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"

# The examples in PR#3042 have slightly different preprocessing from
# HuggingFace's LlavaProcessor, causing the test to fail.
# assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
Expand Down
18 changes: 10 additions & 8 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
from torch import nn
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
Expand Down Expand Up @@ -51,10 +51,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
return hidden_states


def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)

Expand Down Expand Up @@ -151,7 +151,8 @@ def _parse_and_validate_image_input(
return None

if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values")
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

return LlavaImagePixelInputs(
type="pixel_values",
Expand All @@ -166,7 +167,8 @@ def _parse_and_validate_image_input(
return None

if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features")
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")

return LlavaImageFeatureInputs(
type="image_features",
Expand Down Expand Up @@ -268,7 +270,7 @@ def forward(
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)

inputs_embeds = _merge_vision_embeddings(
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)

Expand Down
Loading

0 comments on commit 11e9467

Please sign in to comment.