[V1][TPU] TPU multimodal model support for ragged attention #14158

Merged (3 commits) on Mar 5, 2025

222 changes: 193 additions & 29 deletions vllm/v1/worker/tpu_model_runner.py
@@ -15,14 +15,18 @@
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
NUM_QUERIES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -72,8 +76,10 @@ def __init__(
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
self.max_num_tokens = _get_padded_number(
scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
Collaborator:
Why was padding necessary here? I would expect prepare_inputs to pad the tensors as needed, no?

Member Author:
The issue I ran into was that the user defines max_num_seqs and it doesn't have to be padded. For instance, the user can set max_num_seqs=1 and we will still compile a minimum padded batch of 16. Since the unpadded number was passed to InputBatch, we were not allocating enough padded space in our input buffers.

NUM_QUERIES_PER_BLOCK)
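
To make the thread above concrete, here is a minimal sketch of the padding helper this diff relies on. The concrete value of NUM_QUERIES_PER_BLOCK (16 below) is an assumption for illustration; the real constant is imported from the Pallas attention backend.

```python
# Minimal sketch, not part of the diff. NUM_QUERIES_PER_BLOCK == 16 is an
# assumed value; the real constant comes from vllm/v1/attention/backends/pallas.py.
NUM_QUERIES_PER_BLOCK = 16

def _get_padded_number(n: int, multiple: int) -> int:
    # Same rounding-up helper as defined at the bottom of tpu_model_runner.py.
    return ((n + multiple - 1) // multiple) * multiple

# Even if a user sets max_num_seqs=1, the smallest batch the TPU graphs are
# compiled for is one query block, so the input buffers need 16 slots.
print(_get_padded_number(1, NUM_QUERIES_PER_BLOCK))    # 16
print(_get_padded_number(17, NUM_QUERIES_PER_BLOCK))   # 32
print(_get_padded_number(512, NUM_QUERIES_PER_BLOCK))  # 512
```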

# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -84,25 +90,38 @@ def __init__(
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()

# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
# TODO: Support M-RoPE (e.g, Qwen2-VL)
assert not self.uses_mrope, "TPU does not support M-RoPE yet."

encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size

# Lazy initialization
# self.model: nn.Module # Set after load_model
Contributor:
nit: I think we can just delete the commented line.

Member Author:
I just wanted to line this up with gpu_model_runner.py, so the line was copied over.

self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
vocab_size=model_config.get_vocab_size(),
)

# Request states.
self.requests: dict[str, CachedRequestState] = {}

# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

# KV caches for forward pass
self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []

# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
# Sometimes the numpy op is faster so we create both.
@@ -164,6 +183,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)

# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -177,6 +197,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
if req_index is not None:
removed_req_indices.append(req_index)

# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)

# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
# or running requests that are not scheduled in this step. We remove
@@ -426,6 +454,92 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices

def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return

# Batch the multi-modal inputs.
mm_inputs: list[MultiModalKwargs] = []
req_input_ids: list[tuple[str, int]] = []
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[input_id])
req_input_ids.append((req_id, input_id))

# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
# we process it separately to preserve item order.
# FIXME(ywang96): This is a hacky way to deal with multiple modalities
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)

# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
# in case feature_size is fixed across all multimodal items.
# 2. A list or tuple (length: num_items) of tensors, each of shape
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)

for output in curr_group_outputs:
encoder_outputs.append(output)

# Cache the encoder outputs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = output

def _gather_encoder_outputs(
self,
scheduler_output: "SchedulerOutput",
) -> list[torch.Tensor]:
encoder_outputs: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]

# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens,
# num_computed_tokens + num_scheduled_tokens) and
# [start_pos, start_pos + num_encoder_tokens)
if start_pos >= num_computed_tokens + num_scheduled_tokens:
# The encoder output is not needed in this step.
break
if start_pos + num_encoder_tokens <= num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue

start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs

@torch.no_grad()
def execute_model(
self,
@@ -434,16 +548,42 @@ def execute_model(
# Update cached state
self._update_states(scheduler_output)

if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_encoder(scheduler_output)
encoder_outputs = self._gather_encoder_outputs(scheduler_output)
else:
encoder_outputs = []

# Prepare inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

if self.is_multimodal_model:
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
if encoder_outputs:
inputs_embeds = self.model.get_input_embeddings(
self.input_ids, encoder_outputs)
else:
inputs_embeds = self.model.get_input_embeddings(self.input_ids)
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None

# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
token_ids=self.input_ids,
position_ids=self.position_ids,
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:total_num_scheduled_tokens]
num_reqs = self.input_batch.num_reqs
@@ -538,14 +678,21 @@ def load_model(self) -> None:
fullgraph=True,
dynamic=False)

def dummy_run(
def _dummy_run(
self,
kv_caches,
num_tokens: int,
) -> None:
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
else:
input_ids = torch.zeros((num_tokens),
dtype=torch.int32,
device=self.device)
inputs_embeds = None
position_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
@@ -571,7 +718,10 @@ def dummy_run(
num_seqs=num_tokens,
)

torch._dynamo.mark_dynamic(input_ids, 0)
if self.is_multimodal_model:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
else:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(position_ids, 0)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
@@ -580,7 +730,12 @@

with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(input_ids, position_ids, kv_caches)
self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)

def capture_model(self) -> None:
"""Compile the model."""
@@ -590,11 +745,11 @@ def capture_model(self) -> None:
start = time.perf_counter()
num_tokens = 16
while True:
self.dummy_run(self.kv_caches, num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
end = time.perf_counter()
@@ -647,17 +802,20 @@ def __init__(self, model: nn.Module):

def forward(
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.

Args:
token_ids: The input token IDs of shape [num_tokens].
position_ids: The input position IDs of shape [num_tokens].
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
# Skip this in memory profiling at initialization.
if kv_caches[0][0].numel() > 0:
@@ -684,9 +842,9 @@ def forward(

assert self.model is not None
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)

return hidden_states
@@ -699,6 +857,12 @@ def compute_logits(
logits = self.model.compute_logits(hidden_states, sampling_metadata)
return logits

def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)

def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)


def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
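
For reference, a standalone sketch of the window arithmetic used in _gather_encoder_outputs above, with hypothetical token counts; the helper below is an illustration, not vLLM code.

```python
# Hypothetical walk-through of the slicing in _gather_encoder_outputs.
# A multimodal item spans prompt positions [start_pos, start_pos + num_encoder_tokens);
# the current step covers [num_computed_tokens, num_computed_tokens + num_scheduled_tokens).
def encoder_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                  start_pos: int, num_encoder_tokens: int):
    if start_pos >= num_computed_tokens + num_scheduled_tokens:
        return None  # item not reached in this step
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # item already written into the decoder's KV cache
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    return start_idx, end_idx

# An image whose 100 encoder tokens start at prompt position 50, processed in
# 64-token scheduling steps:
print(encoder_slice(0, 64, 50, 100))    # (0, 14)   first 14 image tokens
print(encoder_slice(64, 64, 50, 100))   # (14, 78)  next 64 image tokens
print(encoder_slice(128, 64, 50, 100))  # (78, 100) remaining 22 tokens
print(encoder_slice(192, 64, 50, 100))  # None      image fully in the KV cache
```
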
2 changes: 1 addition & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -124,7 +124,7 @@ def determine_available_memory(self) -> int:
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)

self.model_runner.dummy_run(
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
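
Finally, a hedged end-to-end usage sketch of what this PR enables: running a multimodal model through vLLM's offline API on a TPU host. The model name, prompt template, and environment setup are illustrative assumptions and are not taken from this PR.

```python
# Illustrative only; model choice, prompt format, and env vars are assumptions.
import os
os.environ["VLLM_USE_V1"] = "1"  # assumed: opt in to the V1 engine

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_num_seqs=8)
image = Image.open("example.jpg")

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```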