[V1][TPU] TPU multimodal model support for ragged attention #14158
Merged
|
@@ -15,14 +15,18 @@ | |
from vllm.attention.layer import Attention | ||
from vllm.config import VllmConfig | ||
from vllm.forward_context import get_forward_context, set_forward_context | ||
from vllm.inputs import INPUT_REGISTRY | ||
from vllm.logger import init_logger | ||
from vllm.model_executor.model_loader import get_model | ||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs | ||
from vllm.multimodal.utils import group_mm_inputs_by_modality | ||
from vllm.sampling_params import SamplingType | ||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available | ||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, | ||
NUM_QUERIES_PER_BLOCK, | ||
PallasAttentionBackend, | ||
PallasMetadata) | ||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget | ||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, | ||
KVCacheSpec) | ||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput | ||
|
@@ -72,8 +76,10 @@ def __init__( | |
self.block_size = cache_config.block_size | ||
self.max_model_len = model_config.max_model_len | ||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) | ||
self.max_num_tokens = scheduler_config.max_num_batched_tokens | ||
self.max_num_reqs = scheduler_config.max_num_seqs | ||
self.max_num_tokens = _get_padded_number( | ||
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK) | ||
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs, | ||
NUM_QUERIES_PER_BLOCK) | ||
|
||
# Model-related. | ||
self.num_attn_layers = model_config.get_num_layers_by_block_type( | ||
|
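The change above rounds both scheduler limits up to a multiple of NUM_QUERIES_PER_BLOCK so the compiled graphs always see block-aligned sizes. A minimal standalone sketch of the arithmetic, assuming a block value of 16 (the real constant comes from the Pallas attention backend):

NUM_QUERIES_PER_BLOCK = 16  # assumed value for illustration

def _get_padded_number(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple (same helper as at the bottom of this file).
    return ((n + multiple - 1) // multiple) * multiple

max_num_batched_tokens = 1000  # hypothetical scheduler setting
max_num_seqs = 1               # hypothetical scheduler setting

print(_get_padded_number(max_num_batched_tokens, NUM_QUERIES_PER_BLOCK))  # 1008
print(_get_padded_number(max_num_seqs, NUM_QUERIES_PER_BLOCK))            # 16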
@@ -84,25 +90,38 @@ def __init__( | |
self.head_size = model_config.get_head_size() | ||
self.hidden_size = model_config.get_hidden_size() | ||
|
||
# Multi-modal data support | ||
self.input_registry = INPUT_REGISTRY | ||
self.mm_registry = MULTIMODAL_REGISTRY | ||
self.uses_mrope = model_config.uses_mrope | ||
# TODO: Support M-RoPE (e.g, Qwen2-VL) | ||
assert not self.uses_mrope, "TPU does not support M-RoPE yet." | ||
|
||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget( | ||
model_config=model_config, | ||
scheduler_config=scheduler_config, | ||
) | ||
self.max_num_encoder_input_tokens = encoder_compute_budget | ||
self.encoder_cache_size = encoder_cache_size | ||
|
||
# Lazy initialization | ||
# self.model: nn.Module # Set after load_model | ||
Review comment: nit: I think we can just delete the commented line.
Author reply: I just wanted to line this up with gpu_model_runner.py, so this was copied.
||
self.kv_caches: list[torch.Tensor] = [] | ||
# req_id -> (input_id -> encoder_output) | ||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} | ||
|
||
# Request states. | ||
self.requests: dict[str, CachedRequestState] = {} | ||
# Persistent batch. | ||
self.input_batch = InputBatch( | ||
max_num_reqs=self.max_num_reqs, | ||
max_model_len=self.max_model_len, | ||
max_num_blocks_per_req=self.max_num_blocks_per_req, | ||
device=self.device, | ||
pin_memory=self.pin_memory, | ||
vocab_size=self.model_config.get_vocab_size(), | ||
vocab_size=model_config.get_vocab_size(), | ||
) | ||
|
||
# Request states. | ||
self.requests: dict[str, CachedRequestState] = {} | ||
|
||
# req_id -> (input_id -> encoder_output) | ||
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} | ||
|
||
# KV caches for forward pass | ||
self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = [] | ||
|
||
# Cached torch/numpy tensor | ||
# The pytorch tensor and numpy array share the same buffer. | ||
# Sometimes the numpy op is faster so we create both. | ||
|
@@ -164,6 +183,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: | |
# Remove finished requests from the cached states. | ||
for req_id in scheduler_output.finished_req_ids: | ||
self.requests.pop(req_id, None) | ||
self.encoder_cache.pop(req_id, None) | ||
|
||
# Remove the finished requests from the persistent batch. | ||
# NOTE(woosuk): There could be an edge case where finished_req_ids and | ||
|
@@ -177,6 +197,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: | |
if req_index is not None: | ||
removed_req_indices.append(req_index) | ||
|
||
# Free the cached encoder outputs. | ||
for req_id, input_id in scheduler_output.free_encoder_input_ids: | ||
encoder_outputs = self.encoder_cache.get(req_id) | ||
if encoder_outputs is not None: | ||
encoder_outputs.pop(input_id, None) | ||
if not encoder_outputs: | ||
self.encoder_cache.pop(req_id, None) | ||
|
||
# Remove the unscheduled requests from the persistent batch. | ||
# NOTE(woosuk): The unscheduled requests are either preempted requests | ||
# or running requests that are not scheduled in this step. We remove | ||
|
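The encoder-output freeing added above works on a plain nested dict (req_id -> input_id -> tensor). A small self-contained sketch of the same pattern, with made-up ids and shapes:

import torch

encoder_cache: dict[str, dict[int, torch.Tensor]] = {
    "req-0": {0: torch.zeros(4, 8), 1: torch.zeros(4, 8)},
}
free_encoder_input_ids = [("req-0", 0), ("req-0", 1)]

for req_id, input_id in free_encoder_input_ids:
    outputs = encoder_cache.get(req_id)
    if outputs is not None:
        outputs.pop(input_id, None)
        if not outputs:
            # Drop the request entry once all of its inputs are freed.
            encoder_cache.pop(req_id, None)

assert "req-0" not in encoder_cache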
@@ -426,6 +454,92 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): | |
logits_indices = query_start_loc[1:] - 1 | ||
return attn_metadata, logits_indices | ||
|
||
def _execute_encoder(self, scheduler_output: "SchedulerOutput"): | ||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs | ||
if not scheduled_encoder_inputs: | ||
return | ||
|
||
# Batch the multi-modal inputs. | ||
mm_inputs: list[MultiModalKwargs] = [] | ||
req_input_ids: list[tuple[str, int]] = [] | ||
for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): | ||
req_state = self.requests[req_id] | ||
for input_id in encoder_input_ids: | ||
mm_inputs.append(req_state.mm_inputs[input_id]) | ||
req_input_ids.append((req_id, input_id)) | ||
|
||
# Batch mm inputs as much as we can: if a request in the batch has | ||
# multiple modalities or a different modality than the previous one, | ||
# we process it separately to preserve item order. | ||
# FIXME(ywang96): This is a hacky way to deal with multiple modalities | ||
# in the same batch while still being able to benefit from batching | ||
# multimodal inputs. The proper solution should be reordering the | ||
# encoder outputs. | ||
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs) | ||
|
||
encoder_outputs = [] | ||
for grouped_mm_inputs in grouped_mm_inputs_list: | ||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs) | ||
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, | ||
device=self.device) | ||
|
||
# Run the encoder. | ||
# `curr_group_outputs` is either of the following: | ||
# 1. A tensor of shape (num_items, feature_size, hidden_size) | ||
# in case feature_size is fixed across all multimodal items. | ||
# 2. A list or tuple (length: num_items) of tensors, each of shape | ||
# (feature_size, hidden_size) in case the feature size is dynamic | ||
# depending on the input multimodal items. | ||
curr_group_outputs = self.model.get_multimodal_embeddings( | ||
**batched_mm_inputs) | ||
|
||
for output in curr_group_outputs: | ||
encoder_outputs.append(output) | ||
|
||
# Cache the encoder outputs. | ||
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): | ||
if req_id not in self.encoder_cache: | ||
self.encoder_cache[req_id] = {} | ||
self.encoder_cache[req_id][input_id] = output | ||
|
||
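The grouping step above batches as aggressively as possible while keeping item order: only consecutive inputs of the same modality end up in one batch. A simplified illustration of that idea using plain dicts instead of MultiModalKwargs (it mirrors the intent of group_mm_inputs_by_modality, not its exact implementation):

from itertools import groupby

mm_inputs = [
    {"modality": "image", "data": 0},
    {"modality": "image", "data": 1},
    {"modality": "audio", "data": 2},
    {"modality": "image", "data": 3},
]

# Group only consecutive items with the same modality so that processing the
# groups one after another preserves the original item order.
grouped = [list(g) for _, g in groupby(mm_inputs, key=lambda x: x["modality"])]
print([[item["data"] for item in group] for group in grouped])  # [[0, 1], [2], [3]]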
def _gather_encoder_outputs( | ||
self, | ||
scheduler_output: "SchedulerOutput", | ||
) -> list[torch.Tensor]: | ||
encoder_outputs: list[torch.Tensor] = [] | ||
for req_id in self.input_batch.req_ids: | ||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ | ||
req_id] | ||
req_state = self.requests[req_id] | ||
num_computed_tokens = req_state.num_computed_tokens | ||
mm_positions = req_state.mm_positions | ||
for i, pos_info in enumerate(mm_positions): | ||
start_pos = pos_info["offset"] | ||
num_encoder_tokens = pos_info["length"] | ||
|
||
# The encoder output is needed if the two ranges overlap: | ||
# [num_computed_tokens, | ||
# num_computed_tokens + num_scheduled_tokens) and | ||
# [start_pos, start_pos + num_encoder_tokens) | ||
if start_pos >= num_computed_tokens + num_scheduled_tokens: | ||
# The encoder output is not needed in this step. | ||
break | ||
if start_pos + num_encoder_tokens <= num_computed_tokens: | ||
# The encoder output is already processed and stored | ||
# in the decoder's KV cache. | ||
continue | ||
|
||
start_idx = max(num_computed_tokens - start_pos, 0) | ||
end_idx = min( | ||
num_computed_tokens - start_pos + num_scheduled_tokens, | ||
num_encoder_tokens) | ||
assert start_idx < end_idx | ||
assert req_id in self.encoder_cache | ||
assert i in self.encoder_cache[req_id] | ||
encoder_output = self.encoder_cache[req_id][i] | ||
encoder_outputs.append(encoder_output[start_idx:end_idx]) | ||
return encoder_outputs | ||
|
||
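A worked numeric example of the interval check and slicing in _gather_encoder_outputs, with hypothetical token counts:

# Hypothetical values for one request and one multimodal item.
num_computed_tokens = 40    # tokens already in the decoder's KV cache
num_scheduled_tokens = 16   # tokens to be processed in this step
start_pos = 32              # offset of the encoder tokens in the prompt
num_encoder_tokens = 20     # length of the cached encoder output

# This step covers [40, 56); the encoder output covers [32, 52); they overlap.
assert start_pos < num_computed_tokens + num_scheduled_tokens
assert start_pos + num_encoder_tokens > num_computed_tokens

start_idx = max(num_computed_tokens - start_pos, 0)  # 8
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)                    # 20
# encoder_output[8:20] is the slice fed to the decoder in this step.
print(start_idx, end_idx)  # 8 20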
@torch.no_grad() | ||
def execute_model( | ||
self, | ||
|
@@ -434,16 +548,42 @@ def execute_model( | |
# Update cached state | ||
self._update_states(scheduler_output) | ||
|
||
if self.is_multimodal_model: | ||
# Run the multimodal encoder if any. | ||
self._execute_encoder(scheduler_output) | ||
encoder_outputs = self._gather_encoder_outputs(scheduler_output) | ||
else: | ||
encoder_outputs = [] | ||
|
||
# Prepare inputs | ||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) | ||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens | ||
|
||
if self.is_multimodal_model: | ||
# NOTE(woosuk): To unify token ids and soft tokens (vision | ||
# embeddings), we always use embeddings (rather than token ids) | ||
# as input to the multimodal model, even when the input is text. | ||
if encoder_outputs: | ||
inputs_embeds = self.model.get_input_embeddings( | ||
self.input_ids, encoder_outputs) | ||
else: | ||
inputs_embeds = self.model.get_input_embeddings(self.input_ids) | ||
input_ids = None | ||
else: | ||
# For text-only models, we use token ids as input. | ||
# While it is possible to use embeddings as input just like the | ||
# multimodal models, it is not desirable for performance since | ||
# then the embedding layer is not included in the CUDA graph. | ||
input_ids = self.input_ids | ||
inputs_embeds = None | ||
|
||
# Run the decoder | ||
with set_forward_context(attn_metadata, self.vllm_config): | ||
hidden_states = self.model( | ||
token_ids=self.input_ids, | ||
position_ids=self.position_ids, | ||
input_ids=input_ids, | ||
positions=self.position_ids, | ||
kv_caches=self.kv_caches, | ||
inputs_embeds=inputs_embeds, | ||
) | ||
hidden_states = hidden_states[:total_num_scheduled_tokens] | ||
num_reqs = self.input_batch.num_reqs | ||
|
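The "always use embeddings for multimodal models" note above can be pictured with a toy merge: text tokens are embedded and placeholder positions are overwritten with the encoder outputs. This is a simplified illustration with a made-up placeholder id, not vLLM's actual merging code (which lives in each model's get_input_embeddings):

import torch

PLACEHOLDER_ID = 9999            # made-up id standing in for an image placeholder token
vocab_size, hidden_size = 10000, 8
embed = torch.nn.Embedding(vocab_size, hidden_size)

input_ids = torch.tensor([1, 2, PLACEHOLDER_ID, PLACEHOLDER_ID, 3])
encoder_output = torch.randn(2, hidden_size)   # one image item -> 2 soft tokens

with torch.no_grad():
    inputs_embeds = embed(input_ids)
    inputs_embeds[input_ids == PLACEHOLDER_ID] = encoder_output

# The decoder is then called with inputs_embeds and input_ids=None.
print(inputs_embeds.shape)  # torch.Size([5, 8])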
@@ -538,14 +678,21 @@ def load_model(self) -> None: | |
fullgraph=True, | ||
dynamic=False) | ||
|
||
def dummy_run( | ||
def _dummy_run( | ||
self, | ||
kv_caches, | ||
num_tokens: int, | ||
) -> None: | ||
input_ids = torch.zeros(num_tokens, | ||
dtype=torch.int32, | ||
device=self.device) | ||
if self.is_multimodal_model: | ||
input_ids = None | ||
inputs_embeds = torch.zeros((num_tokens, self.hidden_size), | ||
dtype=self.dtype, | ||
device=self.device) | ||
else: | ||
input_ids = torch.zeros((num_tokens), | ||
dtype=torch.int32, | ||
device=self.device) | ||
inputs_embeds = None | ||
position_ids = torch.zeros(num_tokens, | ||
dtype=torch.int32, | ||
device=self.device) | ||
|
@@ -571,7 +718,10 @@ def dummy_run( | |
num_seqs=num_tokens, | ||
) | ||
|
||
torch._dynamo.mark_dynamic(input_ids, 0) | ||
if self.is_multimodal_model: | ||
torch._dynamo.mark_dynamic(inputs_embeds, 0) | ||
else: | ||
torch._dynamo.mark_dynamic(input_ids, 0) | ||
torch._dynamo.mark_dynamic(position_ids, 0) | ||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) | ||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) | ||
|
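For readers unfamiliar with mark_dynamic: it declares up front which dimension varies between calls, so one compiled graph can serve all padded token counts instead of recompiling per shape. A generic torch.compile sketch (not the XLA-specific path used here):

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1

compiled = torch.compile(f)

x = torch.zeros(16, dtype=torch.int32)
torch._dynamo.mark_dynamic(x, 0)   # dim 0 (num_tokens) is declared dynamic
compiled(x)

y = torch.zeros(32, dtype=torch.int32)
compiled(y)                        # reuses the dynamic-shape graph instead of recompiling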
@@ -580,7 +730,12 @@ def dummy_run( | |
|
||
with set_forward_context(attn_metadata, self.vllm_config, 0): | ||
assert self.model is not None | ||
self.model(input_ids, position_ids, kv_caches) | ||
self.model( | ||
input_ids=input_ids, | ||
positions=position_ids, | ||
kv_caches=kv_caches, | ||
inputs_embeds=inputs_embeds, | ||
) | ||
|
||
def capture_model(self) -> None: | ||
"""Compile the model.""" | ||
|
@@ -590,11 +745,11 @@ def capture_model(self) -> None: | |
start = time.perf_counter() | ||
num_tokens = 16 | ||
while True: | ||
self.dummy_run(self.kv_caches, num_tokens) | ||
self._dummy_run(self.kv_caches, num_tokens) | ||
logger.info(" -- num_tokens: %d", num_tokens) | ||
xm.mark_step() | ||
xm.wait_device_ops() | ||
if num_tokens >= self.scheduler_config.max_num_batched_tokens: | ||
if num_tokens >= self.max_num_tokens: | ||
break | ||
num_tokens *= 2 | ||
end = time.perf_counter() | ||
|
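capture_model now precompiles graphs for power-of-two token counts up to the padded max_num_tokens rather than the raw scheduler value. The shape schedule in isolation, with a hypothetical padded limit:

max_num_tokens = 1008   # hypothetical value after padding to NUM_QUERIES_PER_BLOCK

num_tokens = 16
shapes = []
while True:
    shapes.append(num_tokens)          # one _dummy_run / compilation per shape
    if num_tokens >= max_num_tokens:
        break
    num_tokens *= 2

print(shapes)  # [16, 32, 64, 128, 256, 512, 1024]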
@@ -647,17 +802,20 @@ def __init__(self, model: nn.Module): | |
|
||
def forward( | ||
self, | ||
token_ids: torch.Tensor, | ||
position_ids: torch.Tensor, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]], | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
"""Executes the forward pass of the model and samples the next token. | ||
|
||
Args: | ||
token_ids: The input token IDs of shape [num_tokens]. | ||
position_ids: The input position IDs of shape [num_tokens]. | ||
input_ids: The input token IDs of shape [num_tokens]. | ||
positions: The input position IDs of shape [num_tokens]. | ||
kv_caches: The key and value caches. They can be None during the | ||
memory profiling at initialization. | ||
inputs_embeds: The input embeddings of shape [num_tokens, | ||
hidden_size]. It is used for multimodal models. | ||
""" | ||
# Skip this in memory profiling at initialization. | ||
if kv_caches[0][0].numel() > 0: | ||
|
@@ -684,9 +842,9 @@ def forward( | |
|
||
assert self.model is not None | ||
hidden_states = self.model( | ||
token_ids, | ||
position_ids, | ||
kv_caches, | ||
input_ids=input_ids, | ||
positions=positions, | ||
inputs_embeds=inputs_embeds, | ||
) | ||
|
||
return hidden_states | ||
|
@@ -699,6 +857,12 @@ def compute_logits( | |
logits = self.model.compute_logits(hidden_states, sampling_metadata) | ||
return logits | ||
|
||
def get_multimodal_embeddings(self, *args, **kwargs): | ||
return self.model.get_multimodal_embeddings(*args, **kwargs) | ||
|
||
def get_input_embeddings(self, *args, **kwargs): | ||
return self.model.get_input_embeddings(*args, **kwargs) | ||
|
||
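The two pass-through methods above keep the wrapper transparent: the runner can call the multimodal helpers without caring whether it holds the compiled wrapper or the raw model. A minimal sketch of that delegation pattern with a toy model (the names below are assumptions for illustration):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(100, 8)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)

class Wrapper(nn.Module):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def get_input_embeddings(self, *args, **kwargs):
        # Delegate to the wrapped model.
        return self.model.get_input_embeddings(*args, **kwargs)

wrapped = Wrapper(ToyModel())
print(wrapped.get_input_embeddings(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 8])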
|
||
def _get_padded_number(n: int, multiple: int) -> int: | ||
return ((n + multiple - 1) // multiple) * multiple |
Review comment: Why was padding necessary here? I would expect prepare_inputs to pad the tensors as needed, no?
Author reply: The issue I ran into was that the user defines max_num_seqs and it doesn't have to be padded. So for instance the user can set max_num_seqs=1 and we will still compile a minimum padded batch of 16. Since the unpadded number was sent to InputBatch, we were not allocating enough padded space in our input buffers.
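A small sketch of the mismatch described above, assuming a minimum compiled batch padded to multiples of 16: sizing InputBatch with the raw max_num_seqs leaves room for fewer requests than the padded batch the compiled graphs actually run with.

NUM_QUERIES_PER_BLOCK = 16  # assumed minimum padding granularity

def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

max_num_seqs = 1  # user-provided setting
padded_max_num_reqs = _get_padded_number(max_num_seqs, NUM_QUERIES_PER_BLOCK)

# Old behavior: buffers sized for 1 request; compiled graphs run a batch of 16.
# New behavior: buffers sized with the padded value, so they always fit.
print(max_num_seqs, padded_max_num_reqs)  # 1 16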