Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VLM][Model] Support image input for Chameleon #6633

Merged
merged 14 commits into from
Jul 23, 2024
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ Vision Language Models
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- :code:`facebook/chameleon-7b` etc.
-
* - :code:`FuyuForCausalLM`
- Fuyu
- :code:`adept/fuyu-8b` etc.
Expand Down
102 changes: 102 additions & 0 deletions tests/models/test_chameleon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import re
from typing import List, Optional, Type

import pytest

from vllm.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"USER: <image>\nWhat's the content of the image?\nASSISTANT:",
"cherry_blossom":
"USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = ["facebook/chameleon-7b"]


#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Test if the model can generate text given
a batch of images and prompts.

"""
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

with vllm_runner(model,
max_model_len=4096,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:

for prompts, images in inputs_per_image:
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
images=images)
for i in range(len(vllm_outputs)):

# format prompt back to original
replacements = {
"<racm3:break>": "",
"<eoss>": "",
"<reserved08706>": ""
}
pattern = '|'.join(replacements.keys())
vllm_result = re.sub(
pattern,
lambda match: replacements[match.group(0)], #noqa B023
vllm_outputs[i][1])
vllm_result = vllm_result.replace("<image>", "", 1023)
assert vllm_result[:len(prompts[i])] == prompts[i]

# assert at least 10 new characters are generated
# (to take stop token into account)
assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
max_tokens: int) -> None:
run_test(
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
tensor_parallel_size=1,
)
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)

if model_type == "chameleon":
return "<image>"
raise TypeError("Unknown model type: {model_type}")


Expand Down
7 changes: 4 additions & 3 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChameleonForCausalLM":
("chameleon", "ChameleonForConditionalGeneration"
), #TODO(ywang96): fix model name when huggingface fixes it
#TODO(ywang96): remove this when huggingface fixes the model repo
"ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
Expand Down
Loading
Loading