[Docs] Enhance SupportsMultiModal interface documentation #19701

Open · wants to merge 1 commit into main
123 changes: 107 additions & 16 deletions vllm/model_executor/models/interfaces.py
@@ -33,12 +33,35 @@

@runtime_checkable
class SupportsMultiModal(Protocol):
"""
The interface required for all multi-modal models in vLLM.

This protocol defines the contract for models that process both text and
multimodal inputs (images, audio, video, etc.). It establishes a standard
workflow: parse multimodal inputs → generate embeddings → merge with text.

The interface is used by:
- V1 model runners for encoder execution and caching
- Speculative decoding for multimodal compatibility
- TPU/GPU optimizations and scheduling

Implementation Pattern:
1. Inherit from SupportsMultiModal in your model class
2. Implement all three required methods
3. Use utilities like merge_multimodal_embeddings() for embedding integration
"""

Member:
Can we use cross-references so they link to the corresponding location in API docs?

Member Author:
Good suggestion. I'll look into how to do that properly.

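One way to act on the cross-reference suggestion would be to turn the utility mention into a linkable identifier. The snippet below is only a sketch: it assumes the vLLM docs build resolves mkdocstrings-style references and that merge_multimodal_embeddings lives under vllm.model_executor.models.utils; the exact syntax depends on the documentation toolchain.

    3. Use utilities like
       [merge_multimodal_embeddings][vllm.model_executor.models.utils.merge_multimodal_embeddings]
       for embedding integration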
supports_multimodal: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports multi-modal inputs.

Used by vLLM's inference engine to identify multimodal models for:
- Encoder cache allocation
- Multimodal-specific scheduling
- LoRA compatibility warnings
- Speculative decoding setup

Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
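The implementation pattern described in the class docstring above can be summarized in a minimal skeleton. This is a sketch rather than code from the PR: the model class, its config attribute image_token_index, and the helpers _parse_and_validate_image_input() / _process_image_input() are hypothetical, while SupportsMultiModal, MultiModalEmbeddings, and merge_multimodal_embeddings are the existing vLLM names the docstrings refer to.

    from typing import Optional

    import torch
    from torch import Tensor

    from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
                                                       SupportsMultiModal)
    from vllm.model_executor.models.utils import merge_multimodal_embeddings


    class MyVLModel(torch.nn.Module, SupportsMultiModal):
        """Hypothetical vision-language model following the documented pattern."""

        def get_multimodal_embeddings(self,
                                      **kwargs: object) -> MultiModalEmbeddings:
            # Parse and validate the modality-specific kwargs (helper elided).
            image_input = self._parse_and_validate_image_input(**kwargs)
            if image_input is None:
                return []
            # Run the vision encoder / projector (helper elided).
            return self._process_image_input(image_input)

        def get_language_model(self) -> torch.nn.Module:
            # Expose the text backbone for speculative decoding, LoRA, etc.
            return self.language_model

        def get_input_embeddings(
            self,
            input_ids: Tensor,
            multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        ) -> Tensor:
            # Embed the text tokens, then splice the multimodal embeddings
            # into the placeholder token positions.
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
            if multimodal_embeddings is not None:
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, multimodal_embeddings,
                    self.config.image_token_index)
            return inputs_embeds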
@@ -47,25 +70,59 @@ class SupportsMultiModal(Protocol):
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
"""
Process multimodal inputs and return embeddings for text integration.

This method is the core of multimodal processing, called by V1 model
runners during encoder execution. It should:
1. Parse and validate multimodal inputs from kwargs
2. Process them through encoders/projectors
3. Return embeddings ready for merging with text

Common Implementation Pattern:
modal_input = self._parse_and_validate_X_input(**kwargs)
if modal_input is None:
return []
return self._process_X_input(modal_input)

Args:
**kwargs: Multimodal inputs (pixel_values, audio_features, etc.)
Varies by model type and supported modalities.

Returns:
MultiModalEmbeddings: Embeddings for text integration. Returns an
empty list/tuple if no multimodal inputs are present.
Embeddings must be ordered to match their placeholder token
positions in the input prompt.

Used by:
- V1 GPU/TPU model runners for batch encoder execution
- Encoder cache manager for caching processed embeddings
- Profiling runs to validate encoder functionality
"""
...

def get_language_model(self) -> torch.nn.Module:
"""
Returns the underlying language model component for text generation.

This method provides access to the core text processing module,
typically the transformer layers that generate text from the merged
text+multimodal embeddings.

Common Implementation:
return self.language_model # Most models
return self.model.decoder # Encoder-decoder (e.g., Whisper)
return self.model # Unified architecture (e.g., Molmo)

Returns:
torch.nn.Module: The language model component responsible for
processing merged embeddings and generating text.

Used by:
- EAGLE speculative decoding for weight sharing with draft models
- TPU model compilation and graph optimization
- LoRA integration (LoRA is applied only to language model layers)
- Model introspection and debugging
"""
...
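To make the parse/encode steps referenced in get_multimodal_embeddings() concrete, here is a sketch of what the hypothetical helpers from the pattern above might look like for an image-only model. The attribute names (vision_tower, multi_modal_projector), the pixel_values kwarg, and the helper names are illustrative assumptions, not vLLM APIs.

    from typing import Optional

    import torch
    from torch import nn


    class _ImageEncodingHelpers:
        """Illustrative helpers behind get_multimodal_embeddings()."""

        vision_tower: nn.Module           # e.g. a ViT: [n_images, n_patches, vision_dim]
        multi_modal_projector: nn.Module  # maps vision_dim -> text hidden_size

        def _parse_and_validate_image_input(
                self, **kwargs: object) -> Optional[torch.Tensor]:
            # Pull the modality-specific tensor out of **kwargs and validate it;
            # return None when the request carries no images.
            pixel_values = kwargs.get("pixel_values")
            if pixel_values is None:
                return None
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("pixel_values must be a torch.Tensor")
            return pixel_values

        def _process_image_input(
                self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, ...]:
            # Encode and project, then split per image so each item can be
            # matched to its placeholder positions in prompt order.
            image_features = self.vision_tower(pixel_values)
            image_embeds = self.multi_modal_projector(image_features)
            return tuple(image_embeds.unbind(0))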

@@ -87,9 +144,43 @@ def get_input_embeddings(
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> Tensor:
"""
Convert input tokens to embeddings and merge with multimodal embeddings.

This method combines text and multimodal inputs into a unified embedding
representation for the language model. It replaces placeholder tokens
(e.g., <image>, <audio>) in the text with corresponding
multimodal embeddings.

Standard Implementation Pattern:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index # or audio_token_index, etc.
)
return inputs_embeds

Args:
input_ids: Text token IDs including placeholder tokens for
multimodal inputs
multimodal_embeddings: Processed multimodal embeddings from
get_multimodal_embeddings(), or None if no
multimodal data

Returns:
Tensor: Combined embeddings ready for language model processing.
Shape: [batch_size, sequence_length, hidden_size]

Used by:
- V1 model runners during embedding conversion phase
- Models that use embeddings instead of raw token IDs
- TPU model compilation for different input shapes

Notes:
- Multimodal embeddings must be ordered to match placeholder
positions
- Uses merge_multimodal_embeddings() utility for token replacement
- Critical for proper text-multimodal alignment in the model
"""
...
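The two embedding methods are meant to be called in sequence, as the docstrings above describe. The helper below is only an illustration of that call order, with a hypothetical embed_prompt() wrapper; the real V1 model runners batch, cache, and schedule these steps rather than calling them this directly.

    from torch import Tensor

    from vllm.model_executor.models.interfaces import SupportsMultiModal


    def embed_prompt(model: SupportsMultiModal, input_ids: Tensor,
                     **mm_kwargs: object) -> Tensor:
        # Step 1: encode any multimodal inputs carried in the kwargs.
        mm_embeds = model.get_multimodal_embeddings(**mm_kwargs)
        if len(mm_embeds) == 0:
            # Text-only request: plain token embeddings are enough.
            return model.get_input_embeddings(input_ids)
        # Step 2: splice the multimodal embeddings into the text embeddings.
        return model.get_input_embeddings(input_ids,
                                          multimodal_embeddings=mm_embeds)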

@@ -392,13 +483,13 @@ def is_attention_free(
@runtime_checkable
class IsHybrid(Protocol):
"""The interface required for all models like Jamba that have both
attention and mamba blocks, indicates that
hf_config has 'layers_block_type'"""

is_hybrid: ClassVar[Literal[True]] = True
"""
A flag that indicates this model has both mamba and attention blocks
, also indicates that the model's hf_config has
'layers_block_type' """
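As with SupportsMultiModal above, a model opts into this interface simply by inheriting it; the ClassVar flag is then picked up through the MRO. A minimal sketch with a hypothetical class name:

    import torch

    from vllm.model_executor.models.interfaces import IsHybrid


    class MyJambaLikeModel(torch.nn.Module, IsHybrid):
        # Inheriting IsHybrid provides is_hybrid = True and signals that
        # hf_config.layers_block_type describes the attention/mamba layer mix.
        ...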

