diff --git a/csrc/ops.h b/csrc/ops.h index f737f50c2ec96..c50eb39a3dacc 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -5,6 +5,30 @@ #include "core/scalar_type.hpp" +#include + +torch::Tensor weak_ref_tensor(torch::Tensor& tensor) { + // Ensure tensor is on CUDA + if (!tensor.is_cuda()) { + throw std::runtime_error("Tensor must be on CUDA device"); + } + + // Get the raw data pointer + void* data_ptr = tensor.data_ptr(); + + // Get tensor sizes and strides + std::vector sizes = tensor.sizes().vec(); + std::vector strides = tensor.strides().vec(); + + // Get tensor options (dtype, device) + auto options = tensor.options(); + + // Create a new tensor from the raw data pointer + auto new_tensor = torch::from_blob(data_ptr, sizes, strides, options); + + return new_tensor; +} + void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index e704ff629fd6e..b8185c24d5628 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -18,6 +18,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops + ops.def("weak_ref_tensor(Tensor input) -> Tensor"); + ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); + // Attention ops // Compute the attention between an input query and the cached // keys/values using PagedAttention. diff --git a/vllm/utils.py b/vllm/utils.py index fba9804289b94..1f75de89d0cc2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1479,3 +1479,12 @@ def __iter__(self): def __len__(self): return len(self._factory) + + +def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: + """ + Create a weak reference to a tensor. + The new tensor will share the same data as the original tensor, + but will not keep the original tensor alive. + """ + return torch.ops._C.weak_ref_tensor(tensor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8b74f06e77be0..4a287e3741d0f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -50,7 +50,7 @@ from vllm.transformers_utils.config import uses_mrope from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, - supports_dynamo) + supports_dynamo, weak_ref_tensor) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1426,12 +1426,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: dtype=self.model_config.dtype, device=self.device) - # Prepare buffer for outputs. These will be reused for all batch sizes. - # It will be filled after the first graph capture. - hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [ - None - ] * self.parallel_config.pipeline_parallel_size - graph_batch_size = self.max_batchsize_to_capture batch_size_capture_list = [ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size @@ -1474,12 +1468,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: input_tokens[:batch_size], "positions": input_positions[..., :batch_size], - "hidden_or_intermediate_states": - hidden_or_intermediate_states[ - virtual_engine] # type: ignore - [:batch_size] - if hidden_or_intermediate_states[virtual_engine] - is not None else None, "intermediate_inputs": intermediate_inputs[:batch_size] if intermediate_inputs is not None else None, @@ -1762,15 +1750,13 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, - hidden_or_intermediate_states: Optional[Union[IntermediateTensors, - torch.Tensor]], intermediate_inputs: Optional[IntermediateTensors], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool: Optional[Tuple[int, int]], stream: torch.cuda.Stream, **kwargs, - ) -> Union[torch.Tensor, IntermediateTensors]: + ): assert self._graph is None # Run the model a few times without capturing the graph. # This is to make sure that the captured graph does not include the @@ -1799,20 +1785,21 @@ def capture( intermediate_tensors=intermediate_inputs, **kwargs, ) - if hidden_or_intermediate_states is not None: - if get_pp_group().is_last_rank: - hidden_or_intermediate_states.copy_( - output_hidden_or_intermediate_states) - else: - for key in hidden_or_intermediate_states.tensors: - hidden_or_intermediate_states[key].copy_( - output_hidden_or_intermediate_states[key]) - else: - hidden_or_intermediate_states = ( + + if isinstance(output_hidden_or_intermediate_states, torch.Tensor): + hidden_or_intermediate_states = weak_ref_tensor( output_hidden_or_intermediate_states) + elif isinstance(output_hidden_or_intermediate_states, + IntermediateTensors): + hidden_or_intermediate_states = IntermediateTensors( + tensors={ + key: weak_ref_tensor(value) + for key, value in + output_hidden_or_intermediate_states.tensors.items() + }) del output_hidden_or_intermediate_states - # make sure `output_hidden_states` is deleted + # make sure `output_hidden_or_intermediate_states` is deleted # in the graph's memory pool gc.collect() torch.cuda.synchronize() @@ -1837,7 +1824,6 @@ def capture( } else: self.output_buffers = hidden_or_intermediate_states - return hidden_or_intermediate_states def forward( self,