
[Misc] Use torch.compile for basic custom ops #7110


Closed
wants to merge 16 commits into from
68 changes: 52 additions & 16 deletions vllm/model_executor/custom_op.py
@@ -1,13 +1,20 @@
import torch
import torch.nn as nn

from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu

# Set this flag to avoid re-compilation caused by `self`.
# NOTE(woosuk): This only works for PyTorch 2.4+.
# In PyTorch 2.5, this flag will no longer be needed.
torch._dynamo.config.inline_inbuilt_nn_modules = True


class CustomOp(nn.Module):

def __init__(self, *args, **kwargs):
super().__init__()
self._is_compiled = False
self._forward_method = self.dispatch_forward()

def forward(self, *args, **kwargs):
@@ -16,38 +23,67 @@ def forward(self, *args, **kwargs):
def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.

This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
This method is optional. If implemented, it can be used with
`torch.compile` to generate the custom kernel for the op.
"""
raise NotImplementedError

def forward_compile(self, *args, **kwargs):
"""Runs a torch-compiled version of forward_native."""
if torch._utils.is_compiling():
return self.forward_native(*args, **kwargs)

if not self._is_compiled:
options = {"cpp_wrapper": True} if is_cpu() else None
self._forward_compiled = torch.compile(self.forward_native,
options=options)
self._is_compiled = True
return self._forward_compiled(*args, **kwargs)

def forward_cuda(self, *args, **kwargs):
raise NotImplementedError
"""Forward method for NVIDIA GPUs.

By default, we use torch.compile to optimize the op. However, we can
override this method to use a custom CUDA implementation if needed.
"""
return self.forward_compile(*args, **kwargs)

def forward_hip(self, *args, **kwargs):
# By default, we assume that HIP ops are compatible with CUDA ops.
"""Forward method for AMD GPUs.

By default, this method is the same as forward_cuda. However, we can
override this method to use a custom HIP implementation if needed.
"""
return self.forward_cuda(*args, **kwargs)

def forward_xpu(self, *args, **kwargs):
raise NotImplementedError
"""Forward method for Intel XPUs.

By default, we use torch.compile to optimize the op. However, we can
override this method to use a custom XPU implementation if needed.
"""
return self.forward_compile(*args, **kwargs)

def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
"""Forward method for CPUs.

By default, we use torch.compile to optimize the op. However, we can
override this method to use a custom CPU implementation if needed.
"""
return self.forward_compile(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        """Forward method for TPUs.

        For TPUs, the whole model is torch-compiled instead of individual ops.
        So, we can just use the native implementation.
        """
        return self.forward_native(*args, **kwargs)

    def forward_gaudi(self, *args, **kwargs):
        # By default, we assume that Gaudi ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

def forward_hpu(self, *args, **kwargs):
"""Forward method for HPUs."""
raise NotImplementedError("HPU is not supported yet.")

def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
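To illustrate the dispatch described in the docstrings above, here is a minimal sketch (not part of the diff) of a subclass that only provides `forward_native`; the class name and the activation it computes are hypothetical. On CUDA, CPU, and XPU builds, `forward()` routes through `forward_compile()`, which wraps `forward_native` in `torch.compile` on first use.

```python
# Illustrative sketch only -- the class name and op are made up for this example.
import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


class GeluMulSketch(CustomOp):
    """Hypothetical op: GELU on the first half of the last dim, times the second half."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]


op = GeluMulSketch()
x = torch.randn(4, 16)
# forward() dispatches per platform; on CUDA/CPU/XPU it reaches
# forward_compile(), which lazily builds torch.compile(forward_native)
# and caches the compiled callable on the instance for later calls.
out = op(x)
assert out.shape == (4, 8)
```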
45 changes: 2 additions & 43 deletions vllm/model_executor/layers/activation.py
@@ -37,15 +37,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.silu_and_mul(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out


class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
@@ -80,18 +71,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_tanh_and_mul(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out

def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'

@@ -111,13 +90,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_new(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

out = torch.empty_like(x)
ops.gelu_new(out, x)
return out


class FastGELU(CustomOp):

@@ -133,13 +105,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_fast(out, x)
return out

def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
from vllm._ipex_ops import ipex_ops as ops

out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out


class QuickGELU(CustomOp):

@@ -155,9 +120,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
ops.gelu_quick(out, x)
return out

# TODO implement forward_xpu for QuickGELU
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ReLUSquaredActivation(CustomOp):
"""
@@ -168,11 +130,8 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return torch.square(F.relu(x))

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)


class ScaledActivation(nn.Module):
class ScaledActivation(CustomOp):
"""An activation function with post-scale parameters.

This is used for some quantization methods like AWQ.
@@ -200,7 +159,7 @@ def __init__(
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return self.act(x) / self.scales

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
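The `forward_xpu` bodies removed above are no longer needed because those backends now fall back to the compiled `forward_native` path. A small standalone sanity check (an assumption-level sketch, not from the PR) that the compiled path matches eager execution for the ReLU-squared activation kept in this file:

```python
# Sanity-check sketch: compiled vs. eager execution of the PyTorch-native op.
import torch
import torch.nn.functional as F


def relu_squared(x: torch.Tensor) -> torch.Tensor:
    # Same math as ReLUSquaredActivation.forward_native above.
    return torch.square(F.relu(x))


compiled_relu_squared = torch.compile(relu_squared)

x = torch.randn(8, 32)
torch.testing.assert_close(compiled_relu_squared(x), relu_squared(x))
```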
32 changes: 0 additions & 32 deletions vllm/model_executor/layers/layernorm.py
@@ -67,30 +67,6 @@ def forward_cuda(
)
return out

def forward_xpu(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops

if residual is not None:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out

def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
@@ -133,11 +109,3 @@ def forward_native(
x = x * (1.0 + self.weight.float())
x = x.to(orig_dtype)
return x if residual is None else (x, residual)

def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
return self.forward_native(x, residual)
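With GemmaRMSNorm's pass-through `forward_cuda` removed, the op now relies on the CustomOp default, which compiles `forward_native`. For reference, a sketch of the Gemma-style RMSNorm math as I read it from the partially shown `forward_native` above (an illustration, not code lifted from the PR):

```python
# Sketch (assumption): Gemma-style RMSNorm, scaling by (1 + weight),
# now served through the compiled default path instead of forward_cuda.
import torch


def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                   eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + weight) rather than weight alone.
    x = x * (1.0 + weight.float())
    return x.to(orig_dtype)


hidden = torch.randn(2, 4, 64, dtype=torch.float16)
weight = torch.zeros(64)
out = torch.compile(gemma_rms_norm)(hidden, weight)
```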
22 changes: 16 additions & 6 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -28,7 +28,6 @@
import torch.nn as nn

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -89,8 +88,6 @@ def __init__(
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)

self.use_native2 = current_platform.is_tpu() and is_neox_style

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
@@ -222,6 +219,18 @@ def forward_cuda(
self.cos_sin_cache, self.is_neox_style)
return query, key

def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.is_neox_style:
return self.forward_native2(positions, query, key, offsets)
else:
return self.forward_native(positions, query, key, offsets)

def forward_xpu(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -252,9 +261,10 @@ def forward_tpu(
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
forward_fn = (self.forward_native2
if self.use_native2 else self.forward_native)
return forward_fn(positions, query, key, offsets)
if self.is_neox_style:
return self.forward_native2(positions, query, key, offsets)
else:
return self.forward_native(positions, query, key, offsets)

def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
28 changes: 11 additions & 17 deletions vllm/model_executor/models/commandr.py
@@ -30,6 +30,7 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
@@ -48,19 +49,7 @@
from vllm.sequence import IntermediateTensors, SamplerOutput


@torch.compile
def layer_norm_func(hidden_states, weight, variance_epsilon):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
variance_epsilon)
hidden_states = weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype)


class LayerNorm(nn.Module):
class LayerNorm(CustomOp):

def __init__(self, param_shape=None, eps=1e-5):
super().__init__()
@@ -69,10 +58,15 @@ def __init__(self, param_shape=None, eps=1e-5):
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})

def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def forward_native(self, hidden_states, residuals=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
mean = hidden_states.mean(-1, keepdim=True)
variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
hidden_states = (hidden_states -
mean) * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight.to(torch.float32) * hidden_states
return hidden_states.to(input_dtype), residuals


# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere