diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 52afda747aab8..b917688a529d1 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -3,7 +3,7 @@ Using VLMs ========== -This document shows you how to run and serve Vision Language Models (VLMs) using vLLM. +vLLM provides experimental support for Vision Language Models (VLMs). This document shows you how to run and serve these models using vLLM. Engine Arguments ---------------- @@ -54,3 +54,69 @@ For now, we only support a single image per text prompt. To pass an image to the print(generated_text) A code example can be found in `examples/llava_example.py `_. + +Online OpenAI Vision API Compatible Inference +---------------------------------------------- + +You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. + +.. note:: + Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be + added in the future. + +Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with vLLM API server. + +.. important:: + Since OpenAI Vision API is based on `Chat `_ API, a chat template + is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the + HuggingFace Llava chat template that you can find in the example folder `here `_. + +.. code-block:: bash + + python -m vllm.entrypoints.openai.api_server \ + --model llava-hf/llava-1.5-7b-hf \ + --image-input-type pixel_values \ + --image-token-id 32000 \ + --image-input-shape 1,3,336,336 \ + --image-feature-size 576 \ + --chat-template template_llava.jinja + +To consume the server, you can use the OpenAI client like in the example below: + +.. code-block:: python + + from openai import OpenAI + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + chat_response = client.chat.completions.create( + model="llava-hf/llava-1.5-7b-hf", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + }, + ], + }], + ) + print("Chat response:", chat_response) + +.. note:: + + By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable: + + .. code-block:: shell + + export VLLM_IMAGE_FETCH_TIMEOUT= + +.. note:: + The prompt formatting with the image token ```` is not needed when serving VLMs with the API server since the prompt will be + processed automatically by the server. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a912949352b86..6248d84683753 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -30,6 +30,8 @@ Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-refer - Chat: `tools`, and `tool_choice`. - Completions: `suffix`. +vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst). + ## Extra Parameters vLLM supports a set of parameters that are not part of the OpenAI API. In order to use them, you can pass them as extra parameters in the OpenAI client. @@ -120,4 +122,4 @@ It is the callers responsibility to prompt the model with the tool information, vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. -Please refer to the OpenAI API reference documentation for more information. \ No newline at end of file +Please refer to the OpenAI API reference documentation for more information. diff --git a/examples/template_llava.jinja b/examples/template_llava.jinja new file mode 100644 index 0000000000000..6a902ee167725 --- /dev/null +++ b/examples/template_llava.jinja @@ -0,0 +1,23 @@ +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {% set system_message = '' -%} +{%- endif -%} + +{{ bos_token + system_message }} +{%- for message in messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif -%} + + {%- if message['role'] == 'user' -%} + {{ 'USER: ' + message['content'] + '\n' }} + {%- elif message['role'] == 'assistant' -%} + {{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {{ 'ASSISTANT:' }} +{% endif %} diff --git a/tests/entrypoints/test_openai_vision.py b/tests/entrypoints/test_openai_vision.py new file mode 100644 index 0000000000000..cc03b04e0b0e0 --- /dev/null +++ b/tests/entrypoints/test_openai_vision.py @@ -0,0 +1,286 @@ +from pathlib import Path +from typing import Dict + +import openai +import pytest +import pytest_asyncio +import ray + +from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64 + +from ..utils import ServerRunner + +MODEL_NAME = "llava-hf/llava-1.5-7b-hf" +LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent / + "examples/template_llava.jinja") +assert LLAVA_CHAT_TEMPLATE.exists() +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + +pytestmark = pytest.mark.openai + + +@pytest.fixture(scope="module") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + "--image-input-type", + "pixel_values", + "--image-token-id", + "32000", + "--image-input-shape", + "1,3,336,336", + "--image-feature-size", + "576", + "--chat-template", + str(LLAVA_CHAT_TEMPLATE), + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest_asyncio.fixture(scope="session") +async def base64_encoded_image() -> Dict[str, str]: + return { + image_url: + encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=596, total_tokens=606) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_base64encoded( + server, client: openai.AsyncOpenAI, model_name: str, image_url: str, + base64_encoded_image: Dict[str, str]): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": + f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=596, total_tokens=606) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_chat_streaming_image(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # test streaming + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + ) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert delta.content + assert "".join(chunks) == output + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_multi_image_input(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py new file mode 100644 index 0000000000000..5a6395ac9e42a --- /dev/null +++ b/tests/multimodal/test_utils.py @@ -0,0 +1,75 @@ +import base64 +import mimetypes +from tempfile import NamedTemporaryFile +from typing import Dict, Tuple + +import numpy as np +import pytest +import pytest_asyncio +from PIL import Image + +from vllm.multimodal.utils import ImageFetchAiohttp + +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + + +@pytest_asyncio.fixture(scope="session") +async def url_images() -> Dict[str, Image.Image]: + return { + image_url: await ImageFetchAiohttp.fetch_image(image_url) + for image_url in TEST_IMAGE_URLS + } + + +def get_supported_suffixes() -> Tuple[str, ...]: + # We should at least test the file types mentioned in GPT-4 with Vision + OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif') + + # Additional file types that are supported by us + EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff') + + return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES + + +def _image_equals(a: Image.Image, b: Image.Image) -> bool: + return (np.asarray(a) == np.asarray(b.convert(a.mode))).all() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("suffix", get_supported_suffixes()) +async def test_fetch_image_base64(url_images: Dict[str, Image.Image], + image_url: str, suffix: str): + url_image = url_images[image_url] + + try: + mime_type = Image.MIME[Image.registered_extensions()[suffix]] + except KeyError: + try: + mime_type = mimetypes.types_map[suffix] + except KeyError: + pytest.skip('No MIME type') + + with NamedTemporaryFile(suffix=suffix) as f: + try: + url_image.save(f.name) + except Exception as e: + if e.args[0] == 'cannot write mode RGBA as JPEG': + pytest.skip('Conversion not supported') + + raise + + base64_image = base64.b64encode(f.read()).decode("utf-8") + data_url = f"data:{mime_type};base64,{base64_image}" + + data_image = await ImageFetchAiohttp.fetch_image(data_url) + if _image_equals(url_image, Image.open(f)): + assert _image_equals(url_image, data_image) + else: + pass # Lossy format; only check that image can be opened diff --git a/vllm/config.py b/vllm/config.py index 4efdb6cab52c4..a980168190adc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -5,7 +5,7 @@ Union) import torch -from transformers import PretrainedConfig +from transformers import PretrainedConfig, PreTrainedTokenizerBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -1119,6 +1119,16 @@ def get_image_input_enum_type(cls, value: str) -> ImageInputType: f"Expecting to choose from " f"{[x.name for x in cls.ImageInputType]}.") from e + #TODO(ywang96): make this a cached property once we refactor the + # VisionLanguageConfig class. + def get_image_token_text( + self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]: + """Get the image token placeholder text to be inserted into the + text prompt and the string representation of the image token id. + """ + image_token_str = tokenizer.decode(self.image_token_id) + return image_token_str * self.image_feature_size, image_token_str + def as_cli_args_dict(self) -> Dict[str, Any]: """Flatten vision language config to pure args. diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 883567abf415b..c025e7e96826c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,15 +1,16 @@ import codecs import time -from dataclasses import dataclass -from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, - Optional) +from dataclasses import dataclass, field +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, + List, Optional) from typing import Sequence as GenericSequence from typing import TypedDict, Union, cast, final from fastapi import Request -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) -from vllm.config import ModelConfig +from vllm.config import ModelConfig, VisionLanguageConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionContentPartParam, ChatCompletionLogProb, @@ -21,9 +22,13 @@ FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) +from vllm.multimodal.image import ImagePixelData +from vllm.multimodal.utils import (async_get_and_parse_image, + get_full_image_text_prompt) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.utils import random_uuid @@ -40,6 +45,8 @@ class ConversationMessage(TypedDict): @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] + image_futures: List[Awaitable[ImagePixelData]] = field( + default_factory=list) class OpenAIServingChat(OpenAIServing): @@ -94,19 +101,76 @@ def _parse_chat_message_content_parts( parts: Iterable[ChatCompletionContentPartParam], ) -> ChatMessageParseResult: texts: List[str] = [] + image_futures: List[Awaitable[ImagePixelData]] = [] - for _, part in enumerate(parts): + vlm_config: Optional[VisionLanguageConfig] = getattr( + self.engine.engine, "vision_language_config", None) + model_config = getattr(self.engine.engine, "model_config", None) + + for part in parts: part_type = part["type"] if part_type == "text": text = cast(ChatCompletionContentPartTextParam, part)["text"] texts.append(text) + elif part_type == "image_url": + if vlm_config is None: + raise ValueError( + "'image_url' input is not supported as the loaded " + "model is not multimodal.") + + elif len(image_futures) == 0: + assert self.tokenizer is not None + image_url = cast(ChatCompletionContentPartImageParam, + part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + image_future = async_get_and_parse_image(image_url["url"]) + image_futures.append(image_future) + + else: + raise NotImplementedError( + "Multiple 'image_url' input is currently not supported." + ) + else: raise NotImplementedError(f"Unknown part type: {part_type}") - messages = [ConversationMessage(role=role, content="\n".join(texts))] + text_prompt = "\n".join(texts) + + if vlm_config is not None and len(image_futures): + + (image_token_prompt, + image_token_str) = vlm_config.get_image_token_text(self.tokenizer) - return ChatMessageParseResult(messages=messages) + # NOTE: If image token string (e.g, ) is already present + # in the text prompt, we assume it follows the same format required + # by the engine. + if image_token_str in text_prompt: + logger.warning( + "Detected image token string in the text prompt. " + "Skipping prompt formatting.") + messages = [ + ConversationMessage(role=role, content=text_prompt) + ] + + else: + full_prompt = get_full_image_text_prompt( + image_prompt=image_token_prompt, + text_prompt=text_prompt, + config=model_config) + messages = [ + ConversationMessage(role=role, content=full_prompt) + ] + else: + messages = [ConversationMessage(role=role, content=text_prompt)] + + return ChatMessageParseResult(messages=messages, + image_futures=image_futures) def _parse_chat_message_content( self, @@ -116,10 +180,10 @@ def _parse_chat_message_content( content = message.get("content") if content is None: - return ChatMessageParseResult(messages=[]) + return ChatMessageParseResult(messages=[], image_futures=[]) if isinstance(content, str): messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages) + return ChatMessageParseResult(messages=messages, image_futures=[]) return self._parse_chat_message_content_parts(role, content) @@ -144,11 +208,13 @@ async def create_chat_completion( try: conversation: List[ConversationMessage] = [] + image_futures: List[Awaitable[ImagePixelData]] = [] for msg in request.messages: - parsed_msg = self._parse_chat_message_content(msg) + chat_parsed_result = self._parse_chat_message_content(msg) - conversation.extend(parsed_msg.messages) + conversation.extend(chat_parsed_result.messages) + image_futures.extend(chat_parsed_result.image_futures) prompt = self.tokenizer.apply_chat_template( conversation=conversation, @@ -159,6 +225,17 @@ async def create_chat_completion( logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) + # Fetch image data + image_data: Optional[ImagePixelData] = None + try: + if len(image_futures): + # since we support only single image currently + assert len(image_futures) == 1 + image_data = await image_futures[0] + except Exception as e: + logger.error("Error in loading image data: %s", e) + return self.create_error_response(str(e)) + request_id = f"cmpl-{random_uuid()}" try: # Tokenize/detokenize depending on prompt format (string/token list) @@ -183,11 +260,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) + inputs: PromptInputs = { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids, + } + if image_data is not None: + inputs["multi_modal_data"] = image_data + result_generator = self.engine.generate( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + inputs, sampling_params, request_id, lora_request, diff --git a/vllm/envs.py b/vllm/envs.py index 7d5c7371b7741..b140aa6d658e6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -29,6 +29,7 @@ VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -216,6 +217,11 @@ # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Timeout for fetching images when serving multimodal models + # Default is 5 seconds + "VLLM_IMAGE_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), } # end-env-vars-definition diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py new file mode 100644 index 0000000000000..b8ad6f8f78e26 --- /dev/null +++ b/vllm/multimodal/utils.py @@ -0,0 +1,85 @@ +import base64 +from io import BytesIO +from typing import Optional, Union + +import aiohttp +from PIL import Image + +from vllm.config import ModelConfig +from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT +from vllm.multimodal.image import ImagePixelData + + +class ImageFetchAiohttp: + aiohttp_client: Optional[aiohttp.ClientSession] = None + + @classmethod + def get_aiohttp_client(cls) -> aiohttp.ClientSession: + if cls.aiohttp_client is None: + timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT) + connector = aiohttp.TCPConnector() + cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout, + connector=connector) + + return cls.aiohttp_client + + @classmethod + async def fetch_image(cls, image_url: str) -> Image.Image: + """Load PIL image from a url or base64 encoded openai GPT4V format""" + + if image_url.startswith('http'): + # Avoid circular import + from vllm import __version__ as VLLM_VERSION + + client = cls.get_aiohttp_client() + headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"} + + async with client.get(url=image_url, headers=headers) as response: + response.raise_for_status() + image_raw = await response.read() + image = Image.open(BytesIO(image_raw)) + + # Only split once and assume the second part is the base64 encoded image + elif image_url.startswith('data:image'): + image = load_image_from_base64(image_url.split(',', 1)[1]) + + else: + raise ValueError("Invalid image url: A valid image url must start " + "with either 'data:image' or 'http'.") + + return image + + +async def async_get_and_parse_image(image_url: str) -> ImagePixelData: + with await ImageFetchAiohttp.fetch_image(image_url) as image: + return ImagePixelData(image) + + +def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: + """encode image to base64 format.""" + + buffered = BytesIO() + if format == 'JPEG': + image = image.convert('RGB') + image.save(buffered, format) + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """Load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +# TODO(ywang96): move this to a model registry for preprocessing vision +# language prompts based on the model type. +def get_full_image_text_prompt(image_prompt: str, text_prompt: str, + config: ModelConfig) -> str: + """Combine image and text prompts for vision language model depending on + the model architecture.""" + + if config.hf_config.model_type == "llava": + full_prompt = f"{image_prompt}\n{text_prompt}" + else: + raise ValueError( + f"Unsupported model type: {config.hf_config.model_type}") + return full_prompt