[V1][PP] Cache Intermediate Tensors (vllm-project#13353)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
WoosukKwon authored and panf2333 committed Feb 18, 2025
1 parent d41dbe6 commit c6f7a31
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2,7 +2,7 @@
 
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
 import torch
@@ -149,6 +149,7 @@ def __init__(
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        # self.intermediate_tensors # Set after load_model
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
@@ -869,7 +870,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> Union[ModelRunnerOutput, torch.Tensor]:
         batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
@@ -919,6 +920,14 @@ def execute_model(
         else:
             positions = self.positions[:num_input_tokens]
 
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
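The hunk above is the heart of the change: on non-first pipeline ranks, execute_model now slices its IntermediateTensors out of a persistent cache instead of receiving freshly allocated tensors. A minimal standalone sketch of the view-slicing pattern follows; the IntermediateTensorsSketch container and all sizes are assumptions for illustration, not vLLM's actual API.

import torch

class IntermediateTensorsSketch:
    """Assumed stand-in for vLLM's IntermediateTensors: a named dict of tensors."""

    def __init__(self, tensors):
        self.tensors = tensors

    def items(self):
        return self.tensors.items()

# Persistent buffers, allocated once at the maximum token count
# (sizes here are illustrative, not taken from the commit).
max_num_tokens, hidden_size = 8192, 256
cached = IntermediateTensorsSketch({
    "hidden_states": torch.zeros(max_num_tokens, hidden_size),
    "residual": torch.zeros(max_num_tokens, hidden_size),
})

# Per step, mirror the dict comprehension in the diff: slicing along
# dim 0 yields zero-copy views into the cached storage.
num_input_tokens = 256
step = IntermediateTensorsSketch(
    {k: v[:num_input_tokens] for k, v in cached.items()})

assert step.tensors["hidden_states"].data_ptr() == \
    cached.tensors["hidden_states"].data_ptr()

Because v[:n] is a zero-copy view, every step reuses the same storage, which is exactly what the pre-existing "Use persistent buffers for CUDA graphs" comment in the surrounding context calls for.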
@@ -931,7 +940,9 @@ def execute_model(
                 inputs_embeds=inputs_embeds,
             )
         if not get_pp_group().is_last_rank:
+            # For mid-pipeline stages, return the hidden states.
             return hidden_states
+
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
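This mid-pipeline early return is what the signature hunk earlier accounts for: execute_model now returns Union[ModelRunnerOutput, torch.Tensor] because only the last rank samples tokens, while every earlier rank hands its raw hidden states to the next stage. A rough sketch of that per-rank contract; the ModelRunnerOutputSketch class, the greedy sampling, and the sizes are assumptions for illustration.

from dataclasses import dataclass
from typing import Union

import torch

@dataclass
class ModelRunnerOutputSketch:
    """Assumed stand-in for vLLM's ModelRunnerOutput."""
    sampled_token_ids: torch.Tensor

def finish_step(hidden_states: torch.Tensor, is_last_rank: bool,
                lm_head: torch.Tensor) -> Union[ModelRunnerOutputSketch, torch.Tensor]:
    if not is_last_rank:
        # Mid-pipeline stages: return the raw activations; the caller
        # ships them to the next stage as its intermediate tensors.
        return hidden_states
    # Last stage only: project to logits and sample (greedy, for brevity).
    logits = hidden_states @ lm_head
    return ModelRunnerOutputSketch(sampled_token_ids=logits.argmax(dim=-1))

hidden = torch.randn(4, 64)
lm_head = torch.randn(64, 1000)  # illustrative vocab size
assert isinstance(finish_step(hidden, False, lm_head), torch.Tensor)
assert isinstance(finish_step(hidden, True, lm_head), ModelRunnerOutputSketch)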
@@ -1118,12 +1129,21 @@ def _dummy_run(
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=num_tokens,
-                dtype=self.model_config.dtype,
-                device=self.device)
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            if not hasattr(self, "intermediate_tensors"):
+                self.intermediate_tensors = (
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=self.max_num_tokens,
+                        dtype=self.model_config.dtype,
+                        device=self.device))
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
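The _dummy_run hunk is where the cache is actually created: the deleted lines allocated fresh intermediate tensors sized to the current num_tokens on every dummy run, while the new code allocates once at self.max_num_tokens (lazily, after the model is loaded, as the comment added in __init__ notes) and thereafter only hands out slices. A sketch of the allocate-once guard under assumed names and shapes:

import torch

class RunnerSketch:
    """Illustrative runner; attribute names mirror the diff, the rest is assumed."""

    def __init__(self, max_num_tokens: int = 8192, hidden_size: int = 256):
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size

    def make_empty_intermediate_tensors(self, batch_size: int) -> dict:
        # Stand-in for model.make_empty_intermediate_tensors(...).
        return {"hidden_states": torch.zeros(batch_size, self.hidden_size)}

    def dummy_run(self, num_tokens: int) -> dict:
        # Allocate lazily, once, at the maximum size; later calls reuse it.
        if not hasattr(self, "intermediate_tensors"):
            self.intermediate_tensors = self.make_empty_intermediate_tensors(
                batch_size=self.max_num_tokens)
        # Hand out views for this run; the underlying storage never moves.
        return {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()}

runner = RunnerSketch()
a = runner.dummy_run(128)
b = runner.dummy_run(256)
assert a["hidden_states"].data_ptr() == b["hidden_states"].data_ptr()

Stable buffer addresses across dummy and real runs are what make captured CUDA graphs replayable: a graph records fixed device pointers, so the tensors it reads and writes must live at the same addresses on every step.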
