
Commit f932e32

DarkLight1337 authored and ywang96 committed
[Model] Initial support for LLaVA-NeXT (vllm-project#4199)
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent 1b41d11 commit f932e32

7 files changed: +640 −18 lines

docs/source/models/supported_models.rst

Lines changed: 5 additions & 1 deletion
@@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it.
     - ✅︎
   * - :code:`LlavaForConditionalGeneration`
     - LLaVA-1.5
-    - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
+    - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
+    -
+  * - :code:`LlavaNextForConditionalGeneration`
+    - LLaVA-NeXT
+    - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
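
For orientation, the sketch below shows what offline inference with one of the newly listed LLaVA-NeXT checkpoints might look like. It is not part of the commit: the keyword arguments mirror the VisionLanguageConfig fields exercised by the new tests (image_input_type, image_token_id, image_input_shape, image_feature_size), while the ImagePixelData import path, the multi_modal_data parameter, and the image path are assumptions that may not match this exact revision.

# Hedged sketch only -- argument and import names are assumptions based on the
# tests in this commit, not a verified example for this revision of vLLM.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData  # import path assumed

llm = LLM(
    model="llava-hf/llava-v1.6-vicuna-7b-hf",
    image_input_type="pixel_values",   # VisionLanguageConfig.ImageInputType.PIXEL_VALUES
    image_token_id=32000,              # matches iter_llava_next_configs() in the new test below
    image_input_shape="1,3,336,336",   # one of the (h, w) shapes covered by the new test
    image_feature_size=1176,           # feature size the test pairs with a 336x336 input
    max_model_len=4096,                # "should be greater than image_feature_size"
    enforce_eager=True,
)

# As in the test, "<image>" is repeated image_feature_size times so the prompt
# reserves one placeholder token per image feature.
prompt = "USER: " + "<image>" * 1176 + "\nWhat's the content of the image? ASSISTANT:"
image = Image.open("tests/images/example.jpg")  # hypothetical path; the fixtures live under tests/images

outputs = llm.generate(prompt,
                       SamplingParams(temperature=0.0, max_tokens=128),
                       multi_modal_data=ImagePixelData(image))  # parameter name assumed
print(outputs[0].outputs[0].text)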

tests/models/test_llava.py

Lines changed: 0 additions & 2 deletions
@@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str):


 model_and_vl_config = [
     *iter_llava_configs("llava-hf/llava-1.5-7b-hf"),
-    # Not enough memory
-    # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"),
 ]

tests/models/test_llava_next.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
from typing import List, Tuple

import pytest
from transformers import AutoTokenizer

from vllm.config import VisionLanguageConfig

from ..conftest import IMAGE_FILES

pytestmark = pytest.mark.llava

_PREFACE = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's "
    "questions.")

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = [
    f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:",
    f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:",
]

assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)


def iter_llava_next_configs(model_name: str):
    image_hw_to_feature_size = {
        (336, 336): 1176,
        (672, 672): 2928,
        (1344, 336): 1944,
        (336, 1344): 1890,
    }

    for (h, w), f in image_hw_to_feature_size.items():
        for input_type, input_shape in [
            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
        ]:
            yield (model_name,
                   VisionLanguageConfig(image_input_type=input_type,
                                        image_feature_size=f,
                                        image_token_id=32000,
                                        image_input_shape=input_shape,
                                        image_processor=model_name,
                                        image_processor_revision=None))


model_and_vl_config = [
    *iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"),
]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str],
                      vlm_config: VisionLanguageConfig, model_id: str):
    """Sanitize vllm output to be comparable with hf output.
    The function reduces `input_ids` from 1, 32000, 32000, ..., 32000,
    x1, x2, x3 ... to 1, 32000, x1, x2, x3 ...
    It also reduces `output_str` from "<image><image>bla" to "bla".
    """
    input_ids, output_str = vllm_output
    image_token_id = vlm_config.image_token_id

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    image_token_str = tokenizer.decode(image_token_id)

    hf_input_ids = [
        input_id for idx, input_id in enumerate(input_ids)
        if input_id != image_token_id or input_ids[idx - 1] != image_token_id
    ]
    hf_output_str = output_str \
        .replace(image_token_str * vlm_config.image_feature_size, " ")

    return hf_input_ids, hf_output_str


@pytest.mark.xfail(
    reason="Inconsistent image processor being used due to lack "
    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(hf_runner, vllm_runner, hf_images, vllm_images,
                model_and_config, dtype: str, max_tokens: int) -> None:
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test is under tests/images.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalData objects and corresponding
    vision language config as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    model_id, vlm_config = model_and_config

    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
                                              max_tokens,
                                              images=hf_images)

    vllm_image_prompts = [
        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
        for p in HF_IMAGE_PROMPTS
    ]

    with vllm_runner(
            model_id,
            dtype=dtype,
            # should be greater than image_feature_size
            max_model_len=4096,
            enforce_eager=True,
            **vlm_config.as_cli_args_dict(),
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                  max_tokens,
                                                  images=vllm_images)

    for i in range(len(HF_IMAGE_PROMPTS)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_to_hf_output(
            vllm_outputs[i], vlm_config, model_id)
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

tests/multimodal/test_processor.py

Lines changed: 55 additions & 7 deletions
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from transformers import CLIPImageProcessor
+from transformers import CLIPImageProcessor, LlavaNextImageProcessor

 from vllm.config import ModelConfig, VisionLanguageConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -12,7 +12,7 @@
 @pytest.mark.parametrize("dtype", ["half", "float"])
 def test_clip_image_processor(hf_images, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-    IMAGE_HEIGHT = IMAGE_WIDTH = 33
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560

     hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
     assert isinstance(hf_processor, CLIPImageProcessor)
@@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype):
             assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"


+@pytest.mark.xfail(
+    reason="Inconsistent image processor being used due to lack "
+    "of support for dynamic image token replacement")
+@pytest.mark.parametrize("dtype", ["half", "float"])
+def test_llava_next_image_processor(hf_images, dtype):
+    MODEL_NAME = "llava-hf/llava-v1.6-34b-hf"
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560
+
+    hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME)
+    assert isinstance(hf_processor, LlavaNextImageProcessor)
+
+    model_config = ModelConfig(
+        model=MODEL_NAME,
+        tokenizer=MODEL_NAME,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype=dtype,
+        revision=None,
+    )
+    vlm_config = VisionLanguageConfig(
+        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+        image_token_id=64000,
+        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+        image_feature_size=2928,
+        image_processor=MODEL_NAME,
+        image_processor_revision=None,
+    )
+
+    for image in hf_images:
+        hf_result = hf_processor.preprocess(
+            image,
+            return_tensors="pt",
+        ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
+        vllm_result = MULTIMODAL_REGISTRY.process_input(
+            ImagePixelData(image),
+            model_config=model_config,
+            vlm_config=vlm_config,
+        )
+
+        assert hf_result.keys() == vllm_result.keys()
+        for key, hf_tensor in hf_result.items():
+            hf_arr: np.ndarray = hf_tensor.numpy()
+            vllm_arr: np.ndarray = vllm_result[key].numpy()
+
+            assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}"
+            assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}"
+
+
+@pytest.mark.xfail(
+    reason="Example image pixels were not processed using HuggingFace")
 @pytest.mark.parametrize("dtype", ["float"])
 def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
     MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
-    IMAGE_HEIGHT = IMAGE_WIDTH = 33
+    IMAGE_HEIGHT = IMAGE_WIDTH = 560

     model_config = ModelConfig(
         model=MODEL_NAME,
@@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
             tensor_arr: np.ndarray = tensor_result[key].numpy()

             assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}"
-
-            # The examples in PR#3042 have slightly different preprocessing from
-            # HuggingFace's LlavaProcessor, causing the test to fail.
-            # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
+            assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}"
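
The xfail reasons above hinge on the fact that, unlike CLIPImageProcessor, the LLaVA-NeXT processor emits a variable number of image patches depending on the input resolution. The snippet below is an illustration of that difference, not part of the commit; the exact output shapes depend on the installed transformers version.

# Illustrative only: compare the two HF processors' outputs for one image.
from PIL import Image
from transformers import CLIPImageProcessor, LlavaNextImageProcessor

image = Image.new("RGB", (560, 560))  # dummy image matching IMAGE_HEIGHT/IMAGE_WIDTH above

clip_processor = CLIPImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
next_processor = LlavaNextImageProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")

clip_out = clip_processor.preprocess(image, return_tensors="pt")
next_out = next_processor.preprocess(image, return_tensors="pt")

print(clip_out["pixel_values"].shape)  # fixed shape, e.g. (1, 3, 336, 336)
print(next_out["pixel_values"].shape)  # (1, num_patches, 3, 336, 336); patch count varies with resolution
print(next_out.keys())                 # also carries "image_sizes", needed to map patches back to tokens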

vllm/model_executor/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlavaForConditionalGeneration":
     ("llava", "LlavaForConditionalGeneration"),
+    "LlavaNextForConditionalGeneration":
+    ("llava_next", "LlavaNextForConditionalGeneration"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),

vllm/model_executor/models/llava.py

Lines changed: 10 additions & 8 deletions
@@ -1,7 +1,7 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

 import torch
-from torch import nn
+import torch.nn as nn
 # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
 # transformers' impl.
 from transformers import CLIPVisionModel, LlavaConfig
@@ -51,10 +51,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         return hidden_states


-def _merge_vision_embeddings(input_ids: torch.Tensor,
-                             inputs_embeds: torch.Tensor,
-                             vision_embeddings: torch.Tensor,
-                             image_token_id: int) -> torch.Tensor:
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: torch.Tensor,
+                            image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)

@@ -151,7 +151,8 @@ def _parse_and_validate_image_input(
                 return None

             if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values")
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")

             return LlavaImagePixelInputs(
                 type="pixel_values",
@@ -166,7 +167,8 @@ def _parse_and_validate_image_input(
                 return None

             if not isinstance(image_features, torch.Tensor):
-                raise ValueError("Incorrect type of image features")
+                raise ValueError("Incorrect type of image features. "
+                                 f"Got type: {type(image_features)}")

             return LlavaImageFeatureInputs(
                 type="image_features",
@@ -268,7 +270,7 @@ def forward(
             vision_embeddings = self._process_image_input(image_input)
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)

-            inputs_embeds = _merge_vision_embeddings(
+            inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.vision_language_config.image_token_id)
