[VLM][Model] Support image input for Chameleon (vllm-project#6633)
ywang96 authored Jul 23, 2024
1 parent c520124 commit 22fa2e3
Showing 7 changed files with 696 additions and 58 deletions.
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -182,6 +182,10 @@ Vision Language Models
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
  * - :code:`ChameleonForConditionalGeneration`
    - Chameleon
    - :code:`facebook/chameleon-7b` etc.
    -
  * - :code:`FuyuForCausalLM`
    - Fuyu
    - :code:`adept/fuyu-8b` etc.
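With this entry in the supported-models table, Chameleon can be exercised through vLLM's multimodal offline-inference API. The snippet below is an illustrative sketch rather than code from this commit: the prompt mirrors the prompts used in the new test, and the image path is a placeholder for any local RGB image.

# Minimal offline-inference sketch (not part of this commit).
# "stop_sign.jpg" is a placeholder path for a local RGB image.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)

prompt = "USER: <image>\nWhat's the content of the image?\nASSISTANT:"
image = Image.open("stop_sign.jpg").convert("RGB")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)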
102 changes: 102 additions & 0 deletions tests/models/test_chameleon.py
@@ -0,0 +1,102 @@
import re
from typing import List, Optional, Type

import pytest

from vllm.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
    "USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = ["facebook/chameleon-7b"]


#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Test if the model can generate text given
    a batch of images and prompts.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    with vllm_runner(model,
                     max_model_len=4096,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:

        for prompts, images in inputs_per_image:
            vllm_outputs = vllm_model.generate_greedy(prompts,
                                                      max_tokens,
                                                      images=images)
            for i in range(len(vllm_outputs)):

                # format prompt back to original
                replacements = {
                    "<racm3:break>": "",
                    "<eoss>": "",
                    "<reserved08706>": ""
                }
                pattern = '|'.join(replacements.keys())
                vllm_result = re.sub(
                    pattern,
                    lambda match: replacements[match.group(0)],  #noqa B023
                    vllm_outputs[i][1])
                vllm_result = vllm_result.replace("<image>", "", 1023)
                assert vllm_result[:len(prompts[i])] == prompts[i]

                # assert at least 10 new characters are generated
                # (to take stop token into account)
                assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
                max_tokens: int) -> None:
    run_test(
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        tensor_parallel_size=1,
    )
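Because the module sets pytestmark = pytest.mark.vlm, the new test can be selected with something like pytest -m vlm tests/models/test_chameleon.py (assuming the vlm marker is registered in the project's pytest configuration), or by pointing pytest at the file directly.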
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
        return None
    if model_type.startswith("llava"):
        return tokenizer.decode(model_config.hf_config.image_token_index)

    if model_type == "chameleon":
        return "<image>"
    raise TypeError(f"Unknown model type: {model_type}")


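For context, _image_token_str supplies the placeholder string that is spliced into the text prompt for each image attached to a chat message; for Chameleon that placeholder is simply "<image>". The helper below is a simplified, hypothetical illustration of that substitution, not code from vLLM.

# Hypothetical illustration only: prepending one placeholder per attached image
# to the user text, the way a chat-prompt renderer might use "<image>".
def render_user_message(text: str, num_images: int,
                        image_token: str = "<image>") -> str:
    if num_images == 0:
        return text
    placeholders = "\n".join(image_token for _ in range(num_images))
    return f"{placeholders}\n{text}"


print(render_user_message("What's the content of the image?", 1))
# <image>
# What's the content of the image?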
7 changes: 4 additions & 3 deletions vllm/model_executor/models/__init__.py
@@ -16,9 +16,10 @@
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChameleonForCausalLM":
("chameleon", "ChameleonForConditionalGeneration"
), #TODO(ywang96): fix model name when huggingface fixes it
#TODO(ywang96): remove this when huggingface fixes the model repo
"ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
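The registry above maps HuggingFace architecture names to a (module, class) pair; after this change, both the misnamed ChameleonForCausalLM and the corrected ChameleonForConditionalGeneration resolve to the same implementation class. The snippet below is a hypothetical sketch of how such a mapping can be resolved lazily, not vLLM's actual model loader.

# Hypothetical sketch of lazy resolution for a (module, class) registry.
# vLLM's real loader involves more machinery (device placement, quantization,
# weight loading); only the name-to-class lookup is illustrated here.
from importlib import import_module

_MODELS = {
    "ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
    "ChameleonForConditionalGeneration":
    ("chameleon", "ChameleonForConditionalGeneration"),
}


def resolve_model_class(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)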
