From e00c094f15e79c5a113fdf975df1ee9018cb65b3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 10 Oct 2024 15:54:23 -0700 Subject: [PATCH] [torch.compile] generic decorators (#9258) --- vllm/compilation/decorators.py | 88 ++++++++++++++++++---------- vllm/model_executor/models/gemma2.py | 10 +++- vllm/model_executor/models/llama.py | 10 +++- 3 files changed, 74 insertions(+), 34 deletions(-) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index b790e5550adb7..655c4c4430179 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,20 +1,54 @@ -from typing import List, Optional, Union +import inspect +from typing import Dict, List, Union import torch import vllm.envs as envs -from vllm.attention import AttentionMetadata from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo -def support_compile_llama_style(cls: type): +def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]): + """ + A decorator to add support for compiling the forward method of a class. + + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic + dimensions of the argument. The dynamic dimensions can be either a single + integer or a list of integers. + + Depending on the value of arguments: + + - if it is a single integer, the corresponding dimension of the argument + will be marked as dynamic. + - if it is `None`, ignored. + - if it is `IntermediateTensors`, all the tensors in the intermediate + tensors will be marked as dynamic. + - otherwise, it will raise an error. + + NOTE: if an argument is `None`, it should always be passed as `None` during + the lifetime of the model, otherwise, it cannot be captured as a single + computation graph. + """ + + def cls_decorator_helper(cls: type): + # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` + # to avoid too much indentation for `_support_torch_compile`` + sig = inspect.signature(cls.forward) + for k in dynamic_arg_dims: + if k not in sig.parameters: + raise ValueError( + f"Argument {k} not found in the forward method of {cls}") + return _support_torch_compile(cls, dynamic_arg_dims) + + return cls_decorator_helper + + +def _support_torch_compile(cls: type, + dynamic_arg_dims: Dict[str, Union[int, List[int]]]): """ A decorator to add support for compiling the forward method of a class. - If a module's **forward signature** is compatible with llama, this - decorator can be used to enable the compilation of the forward method. """ # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner @@ -37,48 +71,42 @@ def __init__(self, *args, **kwargs): cls.__init__ = __init__ - def __call__( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: + def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. if torch.compiler.is_compiling(): - return self.forward(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) + return self.forward(*args, **kwargs) # the first compilation needs to have dynamic shapes marked if len(self.compiled_codes) < 1: - if input_ids is not None: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(positions, 0) - if inputs_embeds is not None: - torch._dynamo.mark_dynamic(inputs_embeds, 0) - if intermediate_tensors is not None: - for tensors in intermediate_tensors.tensors.values(): - torch._dynamo.mark_dynamic(tensors, 0) + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + if isinstance(arg, torch.Tensor): + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, # with the overhead of guard evaluation and recompilation. if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: - return self.compiled_callable(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + return self.compiled_callable(*args, **kwargs) # usually, capturing the model once is enough, and then we can # dispatch to the compiled code directly, without going through # the Dynamo guard mechanism. with self.dispatch_to_code(0): - model_output = self.forward(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) + model_output = self.forward(*args, **kwargs) return model_output cls.__call__ = __call__ diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index edc71435b551f..bcb03ef55ef94 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -21,7 +21,7 @@ from transformers import Gemma2Config from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_compile_llama_style +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger @@ -239,7 +239,13 @@ def forward( return hidden_states, residual -@support_compile_llama_style +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": 0, + "inputs_embeds": 0, + "intermediate_tensors": 0, + }) class Gemma2Model(nn.Module): def __init__( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3f17e9004c30f..ad5cfcc44022f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,7 +28,7 @@ from transformers import LlamaConfig from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_compile_llama_style +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -266,7 +266,13 @@ def forward( return hidden_states, residual -@support_compile_llama_style +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": 0, + "inputs_embeds": 0, + "intermediate_tensors": 0, + }) class LlamaModel(nn.Module): def __init__(