
[Frontend] support image embeds #13955


Merged
merged 1 commit into from Mar 10, 2025
67 changes: 66 additions & 1 deletion docs/source/serving/multimodal_inputs.md
@@ -462,4 +462,69 @@ export VLLM_AUDIO_FETCH_TIMEOUT=<timeout>

### Embedding Inputs

TBD
To input pre-computed embeddings for a supported modality (i.e. image, video, or audio) directly to the language model,
pass a tensor of embeddings to the corresponding field of the multi-modal dictionary.
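
For example, here is a minimal offline-inference sketch for image embeddings (the model name, prompt format, and embedding shape are illustrative; the embedding tensor must match what the model's vision encoder would normally produce):

```python
import torch
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

# Pre-computed image features for a single image; the shape must match
# what the model expects from its vision encoder.
image_embeds = torch.load(...)  # torch.Tensor

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
    "multi_modal_data": {"image": image_embeds},
})
```
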
#### Image Embedding Inputs

For image embeddings, you can pass a base64-encoded tensor to the `image_embeds` field.
The following example demonstrates how to pass image embeddings to the OpenAI-compatible server:

```python
import base64
import io

import torch
from openai import OpenAI

# Assumes a locally running vLLM server with the default settings.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

image_embedding = torch.load(...)
grid_thw = torch.load(...)  # Required by Qwen/Qwen2-VL-2B-Instruct

# Serialize the embedding tensor with torch.save and base64-encode the bytes.
buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')

# Additional tensors (e.g. image_grid_thw for Qwen2-VL) are encoded the same way.
buffer = io.BytesIO()
torch.save(grid_thw, buffer)
buffer.seek(0)
base64_image_grid_thw = base64.b64encode(buffer.read()).decode('utf-8')

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Basic usage - this is equivalent to the LLaVA example for offline inference
model = "llava-hf/llava-1.5-7b-hf"
embeds = {
    "type": "image_embeds",
    "image_embeds": f"{base64_image_embedding}"
}

# Pass additional parameters (available to Qwen2-VL and MiniCPM-V)
model = "Qwen/Qwen2-VL-2B-Instruct"
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_grid_thw": f"{base64_image_grid_thw}"  # Required by Qwen/Qwen2-VL-2B-Instruct
    },
}
model = "openbmb/MiniCPM-V-2_6"
# base64_image_sizes is produced the same way as base64_image_grid_thw above.
embeds = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": f"{base64_image_embedding}",  # Required
        "image_sizes": f"{base64_image_sizes}"  # Required by openbmb/MiniCPM-V-2_6
    },
}
chat_completion = client.chat.completions.create(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [
            {
                "type": "text",
                "text": "What's in this image?",
            },
            embeds,
        ],
        },
    ],
    model=model,
)
```

:::{note}
Only one message can contain `{"type": "image_embeds"}`.
If used with a model that requires additional parameters, you must also provide a tensor for each of them, e.g. `image_grid_thw`, `image_sizes`, etc.
:::
113 changes: 103 additions & 10 deletions vllm/entrypoints/chat_utils.py
@@ -56,6 +56,17 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
"""The type of the content part."""


class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
image_embeds: Required[Union[str, dict[str, str]]]
"""
The image embeddings. It can be either:
- A single base64 string.
- A dictionary where each value is a base64 string.
"""
type: Required[Literal["image_embeds"]]
"""The type of the content part."""


class VideoURL(TypedDict, total=False):
url: Required[str]
"""
@@ -109,6 +120,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str]
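
For illustration, the new `image_embeds` content part defined above can take either of these two shapes (a sketch mirroring the documentation example; the placeholder strings stand in for base64-encoded tensors):

```python
# Single base64-encoded tensor of image embeddings.
part = {
    "type": "image_embeds",
    "image_embeds": "<base64-encoded tensor>",
}

# Dictionary form, for models that need extra tensors alongside the embeddings.
part = {
    "type": "image_embeds",
    "image_embeds": {
        "image_embeds": "<base64-encoded tensor>",
        "image_grid_thw": "<base64-encoded tensor>",  # e.g. for Qwen2-VL
    },
}
```
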

@@ -350,7 +362,7 @@ def resolve_chat_template_content_format(
return detected_format


ModalityStr = Literal["image", "audio", "video"]
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")


@@ -391,7 +403,7 @@ def _placeholder_str(self, modality: ModalityStr,
hf_config = self._model_config.hf_config
model_type = hf_config.model_type

if modality == "image":
if modality in ["image", "image_embeds"]:
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
@@ -470,10 +482,27 @@ def create_parser(self) -> "BaseMultiModalContentParser":
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):

def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items_by_modality:
return dict(self._items_by_modality)

return None
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(\
"Mixing raw image and embedding inputs is not allowed")

if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(\
"Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs

def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
@@ -482,13 +511,31 @@ def create_parser(self) -> "BaseMultiModalContentParser":
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):

async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items_by_modality:
return {
if not self._items_by_modality:
return None
mm_inputs = {}
items_by_modality = {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}

return None
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed")

if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}")
mm_inputs["image"] = image_embeds_lst[0]
elif "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images
elif "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
elif "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs

def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)
@@ -513,6 +560,11 @@ def mm_placeholder_counts(self) -> dict[str, int]:
def parse_image(self, image_url: str) -> None:
raise NotImplementedError

@abstractmethod
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
raise NotImplementedError

@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
@@ -543,6 +595,21 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)

def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
placeholder = self._tracker.add("image_embeds", embeds)

if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding)

self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio = self._connector.fetch_audio(audio_url)

@@ -579,6 +646,25 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()

if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items()
}
future.set_result(embeds)

if isinstance(image_embeds, str):
embedding = self._connector.\
fetch_image_embedding(image_embeds)
future.set_result(embedding)

placeholder = self._tracker.add("image_embeds", future)
self._add_placeholder(placeholder)

def parse_audio(self, audio_url: str) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url)

@@ -684,6 +770,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
@@ -700,6 +787,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
lambda part: _TextParser(part).get("text", ""),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", ""),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", {}),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""),
"input_audio":
@@ -769,6 +858,7 @@ def _parse_chat_message_content_mm_part(


VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
"image_embeds",
"audio_url", "input_audio", "video_url")


@@ -843,7 +933,10 @@ def _parse_chat_message_content_part(
str_content = cast(str, content)
mm_parser.parse_image(str_content)
return {'type': 'image'} if wrap_dicts else None

if part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content)
return {'type': 'image'} if wrap_dicts else None
if part_type == "audio_url":
str_content = cast(str, content)
mm_parser.parse_audio(str_content)
19 changes: 19 additions & 0 deletions vllm/multimodal/image.py
@@ -134,3 +134,22 @@ def encode_base64(
data = buffer.getvalue()

return base64.b64encode(data).decode('utf-8')


class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):

def __init__(self) -> None:
super().__init__()

def load_bytes(self, data: bytes) -> torch.Tensor:
buffer = BytesIO(data)
return torch.load(buffer, weights_only=True)

def load_base64(self, media_type: str, data: str) -> torch.Tensor:
return self.load_bytes(base64.b64decode(data))

def load_file(self, filepath: Path) -> torch.Tensor:
return torch.load(filepath)

def encode_base64(self, media: torch.Tensor) -> str:
return base64.b64encode(media.numpy()).decode('utf-8')
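
A small round-trip sketch of how this loader pairs with the client-side encoding shown in the documentation above (assuming the tensor was serialized with `torch.save`; note that `load_base64` ignores its `media_type` argument):

```python
import base64
import io

import torch

from vllm.multimodal.image import ImageEmbeddingMediaIO

# Client side: serialize a tensor and base64-encode the bytes,
# exactly as in the documentation example.
tensor = torch.randn(16, 4096)
buf = io.BytesIO()
torch.save(tensor, buf)
payload = base64.b64encode(buf.getvalue()).decode("utf-8")

# Server side: decode the payload back into a tensor.
restored = ImageEmbeddingMediaIO().load_base64("", payload)
assert torch.equal(tensor, restored)
```
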
14 changes: 13 additions & 1 deletion vllm/multimodal/utils.py
@@ -7,6 +7,7 @@

import numpy as np
import numpy.typing as npt
import torch
from PIL import Image

import vllm.envs as envs
@@ -16,7 +17,7 @@

from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO

@@ -245,6 +246,17 @@ async def fetch_video_async(
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)

def fetch_image_embedding(
self,
data: str,
) -> torch.Tensor:
"""
Load an image embedding from a base64-encoded string.
"""
image_embedding_io = ImageEmbeddingMediaIO()

return image_embedding_io.load_base64("", data)


global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""