Fix LLaVA example and test w.r.t. image processing refactor
- Note that we now load the images directly instead of from `.pt` files
DarkLight1337 committed Apr 19, 2024
1 parent 91ea044 commit cb19743
Showing 3 changed files with 70 additions and 51 deletions.
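
For context, a minimal sketch (not part of the commit) contrasting the old and new image-loading paths in `examples/llava_example.py`. The file paths, tensor shape, and `ImagePixelData` usage are taken from the diff below; the surrounding `LLM` setup and prompt are omitted, and the `PIL` import is assumed.

```python
from PIL import Image
import torch

# Before this commit: the example loaded precomputed pixel values from a
# `.pt` file and converted them back into a PIL image by hand.
image_tensor: torch.Tensor = torch.load("images/stop_sign_pixel_values.pt")
image_arr = image_tensor.view(3, 336, 336).permute((1, 2, 0)).numpy()
image = Image.fromarray(image_arr, mode="RGB")

# After this commit: the raw image is opened directly and preprocessing is
# left to vLLM's image processor (the `no_image_processor=True` flag is gone).
image = Image.open("images/stop_sign.jpg")

# In both cases the image is passed to generation wrapped in ImagePixelData:
#     outputs = llm.generate(prompt, multi_modal_datas=ImagePixelData(image))
```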
examples/llava_example.py (6 changes: 1 addition & 5 deletions)
@@ -18,16 +18,13 @@ def run_llava_pixel_values():
         image_token_id=32000,
         image_input_shape="1,3,336,336",
         image_feature_size=576,
-        no_image_processor=True,
     )
 
     prompt = "<image>" * 576 + (
         "\nUSER: What is the content of this image?\nASSISTANT:")
 
     # This should be provided by another online or offline component.
-    image_tensor: torch.Tensor = torch.load("images/stop_sign_pixel_values.pt")
-    image_arr = image_tensor.view(3, 336, 336).permute((1, 2, 0)).numpy()
-    image = Image.fromarray(image_arr, mode="RGB")
+    image = Image.open("images/stop_sign.jpg")
 
     outputs = llm.generate(prompt, multi_modal_datas=ImagePixelData(image))
     for o in outputs:
@@ -42,7 +39,6 @@ def run_llava_image_features():
         image_token_id=32000,
         image_input_shape="1,576,1024",
         image_feature_size=576,
-        no_image_processor=True,
     )
 
     prompt = "<image>" * 576 + (
tests/conftest.py (22 changes: 10 additions & 12 deletions)
@@ -13,18 +13,14 @@
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.sequence import MultiModalData
+from vllm.sequence import ImageFeatureData, ImagePixelData, MultiModalData
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 _TEST_DIR = os.path.dirname(__file__)
 _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
 _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]
 
 # Multi modal related
-_PIXEL_VALUES_FILES = [
-    os.path.join(_TEST_DIR, "images", filename) for filename in
-    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
-]
 _IMAGE_FEATURES_FILES = [
     os.path.join(_TEST_DIR, "images", filename) for filename in
     ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
@@ -37,8 +33,7 @@
     "<image>\nUSER: What's the content of the image?\nASSISTANT:",
     "<image>\nUSER: What is the season?\nASSISTANT:"
 ]
-assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
-    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
+assert len(_IMAGE_FEATURES_FILES) == len(_IMAGE_FILES) == len(_IMAGE_PROMPTS)
 
 
 def _read_prompts(filename: str) -> List[str]:
@@ -86,15 +81,18 @@ def hf_images() -> List[Image.Image]:
 
 
 @pytest.fixture()
-def vllm_images(request) -> List[torch.Tensor]:
+def vllm_images(request) -> List[MultiModalData]:
     vision_language_config = request.getfixturevalue("model_and_config")[1]
     if vision_language_config.image_input_type == (
             VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
-        filenames = _IMAGE_FEATURES_FILES
+        return [
+            ImageFeatureData(torch.load(filename))
+            for filename in _IMAGE_FEATURES_FILES
+        ]
     else:
-        filenames = _PIXEL_VALUES_FILES
-
-    return [torch.load(filename) for filename in filenames]
+        return [
+            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
+        ]
 
 
 @pytest.fixture()
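
For reference, a short sketch (not part of the diff; the image file path in the pixel branch is illustrative) of the two multi-modal input types the updated `vllm_images` fixture returns: `ImagePixelData` wraps a raw PIL image, while `ImageFeatureData` wraps a precomputed feature tensor loaded from a `.pt` file.

```python
from PIL import Image
import torch

from vllm.sequence import ImageFeatureData, ImagePixelData

# Pixel-value input: pass the raw image and let the per-model image processor
# (configured via `image_processor=model_name` in the test configs below)
# handle preprocessing.
pixel_data = ImagePixelData(Image.open("tests/images/stop_sign.jpg"))

# Image-feature input: pass precomputed vision features and bypass the image
# processor (HuggingFace has no equivalent path, hence the special-casing in
# the test below).
feature_data = ImageFeatureData(
    torch.load("tests/images/stop_sign_image_features.pt"))
```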
tests/models/test_llava.py (93 changes: 59 additions & 34 deletions)
@@ -8,36 +8,47 @@
 from transformers import AutoTokenizer
 
 from vllm.config import VisionLanguageConfig
+from vllm.sequence import ImagePixelData
 
 
 def iter_llava_configs(model_name: str):
-    for input_type, input_shape in [
-        (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, 336, 336)),
-        (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, 576, 1024)),
-    ]:
-        yield (model_name,
-               VisionLanguageConfig(image_input_type=input_type,
-                                    image_feature_size=576,
-                                    image_token_id=32000,
-                                    image_input_shape=input_shape,
-                                    image_processor=None,
-                                    image_processor_revision=None))
+    image_hw_to_feature_size = {
+        (336, 336): 576,
+    }
+
+    for (h, w), f in image_hw_to_feature_size.items():
+        for input_type, input_shape in [
+            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
+            (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
+        ]:
+            yield (model_name,
+                   VisionLanguageConfig(image_input_type=input_type,
+                                        image_feature_size=f,
+                                        image_token_id=32000,
+                                        image_input_shape=input_shape,
+                                        image_processor=model_name,
+                                        image_processor_revision=None))
 
 
 def iter_llava_next_configs(model_name: str):
-    for input_type, input_shape in [
-        # `vision_config` on HuggingFace only supports `image_size=336`
-        (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, 336, 336)),
-        (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, 576, 1024)),
-    ]:
-        yield (model_name,
-               VisionLanguageConfig(image_input_type=input_type,
-                                    image_feature_size=576,
-                                    image_token_id=64000,
-                                    image_input_shape=input_shape,
-                                    image_processor=None,
-                                    image_processor_revision=None))
+    image_hw_to_feature_size = {
+        (336, 336): 1176,
+        (672, 672): 2928,
+        (1344, 336): 1944,
+        (336, 1344): 1890,
+    }
+
+    for (h, w), f in image_hw_to_feature_size.items():
+        for input_type, input_shape in [
+            (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
+            (VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
+        ]:
+            yield (model_name,
+                   VisionLanguageConfig(image_input_type=input_type,
+                                        image_feature_size=f,
+                                        image_token_id=64000,
+                                        image_input_shape=input_shape,
+                                        image_processor=model_name,
+                                        image_processor_revision=None))
 
 
 model_and_vl_config = [
@@ -99,37 +110,51 @@ def test_models(hf_runner, vllm_runner, hf_image_prompts, hf_images,
     """Inference result should be the same between hf and vllm.
     All the image fixtures for the test is under tests/images.
-    For huggingface runner, we provide the raw images as input.
-    For vllm runner, we provide image tensors and corresponding
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalData objects and corresponding
     vision language config as input.
     Note, the text input is also adjusted to abide by vllm contract.
     The text output is sanitized to be able to compare with hf.
     """
     model_id, vision_language_config = model_and_config
 
     hf_model = hf_runner(model_id, dtype=dtype)
-    hf_outputs = hf_model.generate_greedy(hf_image_prompts,
-                                          max_tokens,
-                                          images=hf_images)
+    _, vision_language_config = model_and_config
+    if vision_language_config.image_input_type == (
+            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
+        # HuggingFace does not support image feature input
+        hf_outputs = [None] * len(hf_image_prompts)
+    else:
+        _, _, h, w = vision_language_config.image_input_shape
+        hf_outputs = hf_model.generate_greedy(
+            hf_image_prompts,
+            max_tokens,
+            # To be compatible with the patch for LLaVA-NeXT
+            images=[im.resize((w, h)) for im in hf_images])
     del hf_model
 
     vllm_model = vllm_runner(model_id,
                              dtype=dtype,
                              worker_use_ray=worker_use_ray,
                              enforce_eager=True,
                              **as_dict(vision_language_config))
-    vllm_outputs = vllm_model.generate_greedy(
-        vllm_image_prompts,
-        max_tokens,
-        multi_modal_datas=[ImagePixelData(image) for image in vllm_images])
+    vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
+                                              max_tokens,
+                                              multi_modal_datas=vllm_images)
     del vllm_model
 
     gc.collect()
     torch.cuda.empty_cache()
 
     for i in range(len(hf_image_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
+        hf_output = hf_outputs[i]
+        if hf_output is None:
+            continue
+
+        hf_output_ids, hf_output_str = hf_output
         vllm_output_ids, vllm_output_str = sanitize_vllm_output(
             vllm_outputs[i], vision_language_config, model_id)
         print(f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_str == vllm_output_str, (
             f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
         assert hf_output_ids == vllm_output_ids, (
