Support image processor
- Also add docs for basic VLM usage
DarkLight1337 committed Apr 19, 2024
1 parent 221d93e commit adf2b94
Showing 19 changed files with 592 additions and 191 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`, `llava-hf/llava-1.5-13b-hf`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -85,6 +85,7 @@ Documentation
models/adding_model
models/engine_args
models/lora
models/vlm

.. toctree::
:maxdepth: 1
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -83,6 +83,10 @@ Alongside each architecture, we include some popular models that use it.
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
* - :code:`LlavaForConditionalGeneration`
- LLaVA-1.5
- :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc.
-
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
77 changes: 77 additions & 0 deletions docs/source/models/vlm.rst
@@ -0,0 +1,77 @@
.. _vlm:

Using VLMs
==========

This document shows you how to run and serve Vision Language Models (VLMs) using vLLM.

Additional Engine Arguments
---------------------------

In addition to the :ref:`basic engine arguments <engine_args>`, vLLM requires the following engine arguments to run VLMs.

.. option:: --image-input-type {pixel_values,image_features}

The image input type passed into vLLM. Should be one of "pixel_values" or "image_features".

.. option:: --image-token-id <id>

Input ID for the image token.

.. option:: --image-input-shape <tuple>

The largest image input shape (the worst case for memory footprint) for the given input type. Only used for vLLM's profile_run.

For example, if the image tensor has shape :code:`(1, 3, 336, 336)`, then you should pass :code:`--image-input-shape 1,3,336,336`.

.. option:: --image-feature-size <size>

The image feature size along the context dimension (e.g. 576 for LLaVA-1.5; see the sketch below for one way to derive it).

.. option:: --image-processor <name or path>

Name or path of the huggingface image processor to use.

.. option:: --image-processor-revision <revision>

The specific image processor version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.

.. option:: --no-image-processor

Disables the use of the image processor, even if one is defined for the model on Hugging Face.
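
The values used later in this document (``--image-token-id 32000``, ``--image-input-shape 1,3,336,336`` and ``--image-feature-size 576`` for LLaVA-1.5) can be read off the model's Hugging Face config. The following is only a sketch for sanity-checking them, not part of vLLM itself; it assumes the config exposes ``vision_config`` and ``image_token_index``, as ``transformers``' ``LlavaConfig`` does.

.. code-block:: python

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    vision_config = config.vision_config

    # CLIP ViT-L/14 at 336x336 resolution: 336 // 14 = 24 patches per side.
    image_size = vision_config.image_size
    patch_size = vision_config.patch_size

    print("--image-token-id", config.image_token_index)             # 32000
    print("--image-input-shape", f"1,3,{image_size},{image_size}")  # 1,3,336,336
    print("--image-feature-size", (image_size // patch_size) ** 2)  # 576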

Offline Batched Inference
-------------------------

To initialize a VLM, pass the arguments above to the ``LLM`` class when instantiating the engine.

.. code-block:: python

    from vllm import LLM

    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

For now, we only support a single image per text prompt when calling ``llm.generate``. To pass an image to the model, note the following parameters:

* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_datas``: This should be an instance of ``ImagePixelData``.

.. code-block:: python

    from PIL import Image

    from vllm.sequence import ImagePixelData

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # Load the image using PIL.Image
    image = ...

    outputs = llm.generate(prompt, multi_modal_datas=ImagePixelData(image))

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)

A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_.
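
``llm.generate`` also accepts a list of prompts together with a list containing one ``ImagePixelData`` per prompt, as exercised by ``tests/conftest.py`` in this commit. A minimal sketch, reusing ``llm`` from above (the image paths are placeholders):

.. code-block:: python

    from PIL import Image

    from vllm.sequence import ImagePixelData

    image_paths = ["stop_sign.jpg", "cherry_blossom.jpg"]  # placeholder files
    prompts = [
        "<image>" * 576 + "\nUSER: What's the content of the image?\nASSISTANT:",
        "<image>" * 576 + "\nUSER: What is the season?\nASSISTANT:",
    ]

    outputs = llm.generate(
        prompts,
        multi_modal_datas=[ImagePixelData(Image.open(p)) for p in image_paths],
    )
    for o in outputs:
        print(o.outputs[0].text)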
15 changes: 6 additions & 9 deletions examples/llava_example.py
@@ -3,9 +3,10 @@
import subprocess

import torch
+from PIL import Image

from vllm import LLM
-from vllm.sequence import MultiModalData
+from vllm.sequence import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.

@@ -23,11 +24,9 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_pixel_values.pt")
+    image = Image.open("images/stop_sign.jpg")

-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
+    outputs = llm.generate(prompt, multi_modal_datas=ImagePixelData(image))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
@@ -46,11 +45,9 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_image_features.pt")
+    image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
+    outputs = llm.generate(prompt, multi_modal_datas=ImageFeatureData(image))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -10,6 +10,7 @@ tokenizers >= 0.19.1 # Required for Llama 3.
fastapi
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
3 changes: 0 additions & 3 deletions requirements-dev.txt
@@ -30,6 +30,3 @@ ai2-olmo # required for OLMo

# Benchmarking
aiohttp

# Multimodal
pillow
56 changes: 28 additions & 28 deletions tests/conftest.py
@@ -12,18 +12,14 @@
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
-from vllm.sequence import MultiModalData
+from vllm.sequence import ImageFeatureData, ImagePixelData, MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi modal related
-_PIXEL_VALUES_FILES = [
-    os.path.join(_TEST_DIR, "images", filename) for filename in
-    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
-]
_IMAGE_FEATURES_FILES = [
os.path.join(_TEST_DIR, "images", filename) for filename in
["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
@@ -36,8 +32,7 @@
"<image>\nUSER: What's the content of the image?\nASSISTANT:",
"<image>\nUSER: What is the season?\nASSISTANT:"
]
-assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
-    _IMAGE_FILES) == len(_IMAGE_PROMPTS)
+assert len(_IMAGE_FEATURES_FILES) == len(_IMAGE_FILES) == len(_IMAGE_PROMPTS)


def _read_prompts(filename: str) -> List[str]:
@@ -85,17 +80,18 @@ def hf_images() -> List[Image.Image]:


@pytest.fixture()
-def vllm_images(request) -> "torch.Tensor":
+def vllm_images(request) -> List[MultiModalData]:
    vision_language_config = request.getfixturevalue("model_and_config")[1]
-    all_images = []
    if vision_language_config.image_input_type == (
            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
-        filenames = _IMAGE_FEATURES_FILES
+        return [
+            ImageFeatureData(torch.load(filename))
+            for filename in _IMAGE_FEATURES_FILES
+        ]
    else:
-        filenames = _PIXEL_VALUES_FILES
-    for filename in filenames:
-        all_images.append(torch.load(filename))
-    return torch.concat(all_images, dim=0)
+        return [
+            ImagePixelData(Image.open(filename)) for filename in _IMAGE_FILES
+        ]


@pytest.fixture()
@@ -172,15 +168,17 @@ def generate(
images: Optional[List[Image.Image]] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
-        if images:
+        if images is not None:
            assert len(prompts) == len(images)

+        outputs: List[Tuple[List[int], str]] = []
for i, prompt in enumerate(prompts):
if self.model_name not in _VISION_LANGUAGE_MODELS:
input_ids = self.tokenizer(prompt,
return_tensors="pt").input_ids
inputs = {"input_ids": input_ids.cuda()}
else:
+                assert self.processor is not None
image = images[i] if images else None
inputs = self.processor(text=prompt,
images=image,
@@ -189,6 +187,7 @@ def generate(
key: value.cuda() if value is not None else None
for key, value in inputs.items()
}

output_ids = self.model.generate(
**inputs,
use_cache=True,
@@ -207,7 +206,7 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
do_sample=False,
@@ -316,16 +315,15 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
+        multi_modal_datas: Optional[List[Optional[MultiModalData]]] = None,
) -> List[Tuple[List[int], str]]:
-        if images is not None:
-            assert len(prompts) == images.shape[0]
-        req_outputs = self.model.generate(
-            prompts,
-            sampling_params=sampling_params,
-            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
-                                            data=images)
-            if images is not None else None)
+        if multi_modal_datas is not None:
+            assert len(prompts) == len(multi_modal_datas)

+        req_outputs = self.model.generate(prompts,
+                                          sampling_params=sampling_params,
+                                          multi_modal_datas=multi_modal_datas)

outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt
@@ -362,10 +360,12 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
-        images: Optional[torch.Tensor] = None,
+        multi_modal_datas: Optional[List[Optional[MultiModalData]]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        outputs = self.generate(prompts, greedy_params, images=images)
+        outputs = self.generate(prompts,
+                                greedy_params,
+                                multi_modal_datas=multi_modal_datas)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

