Skip to content

Commit

Permalink
[Core][CUDA Graph] add output buffer for cudagraph (vllm-project#5074)
Browse files Browse the repository at this point in the history
[Core][CUDA Graph] add output buffer for cudagraph to reduce memory footprint (vllm-project#5074)
  • Loading branch information
youkaichao authored and jimpang committed Jul 24, 2024
1 parent 77891b8 commit bf8ef0b
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import gc
import time
import warnings
from collections import defaultdict
Expand Down Expand Up @@ -894,6 +895,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()

# Prepare buffer for outputs. These will be reused for all batch sizes.
# It will be filled after the first graph capture.
hidden_states: Optional[torch.Tensor] = None

graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
Expand Down Expand Up @@ -930,9 +935,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
self.set_active_loras(set(), lora_mapping)

graph_runner = CUDAGraphRunner(self.model)
graph_runner.capture(
hidden_states = graph_runner.capture(
input_tokens[:batch_size],
input_positions[:batch_size],
hidden_states[:batch_size]
if hidden_states is not None else None,
kv_caches,
attn_metadata,
memory_pool=self.graph_memory_pool,
Expand Down Expand Up @@ -969,12 +976,13 @@ def capture(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs,
) -> None:
) -> torch.Tensor:
assert self._graph is None
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
Expand All @@ -993,13 +1001,21 @@ def capture(
# Capture the graph.
self._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
hidden_states = self.model(
output_hidden_states = self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
if hidden_states is not None:
hidden_states.copy_(output_hidden_states)
else:
hidden_states = output_hidden_states
del output_hidden_states
# make sure `output_hidden_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize()

# Save the input and output buffers.
Expand All @@ -1012,7 +1028,7 @@ def capture(
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
return hidden_states

def forward(
self,
Expand Down

0 comments on commit bf8ef0b

Please sign in to comment.