[Misc] Use torch.compile for basic custom ops #7110

Open: wants to merge 16 commits into base: main
vllm/model_executor/custom_op.py (88 changes: 71 additions & 17 deletions)

@@ -1,52 +1,106 @@
+import torch
import torch.nn as nn

+import vllm.envs as envs
from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu


class CustomOp(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
+        self._is_compiled = False
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

-    def forward_native(self, *args, **kwargs):
+    @staticmethod
+    def forward_static(*args, **kwargs):
        """PyTorch-native implementation of the forward method.

-        This method is optional. If implemented, it can be used with compilers
-        such as torch.compile or PyTorch XLA. Also, it can be used for testing
-        purposes.
+        This method is optional. If implemented, it can be used with
+        `torch.compile` to generate the custom kernel for the op.
+        This function is necessary because any method that uses `self`
+        triggers re-compilation in `torch.compile`.
        """
        raise NotImplementedError

+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+
+        This method is expected to invoke the forward_static method
+        with the input arguments and the attributes of `self`.
+        By default, we assume that forward_static does not need any
+        attributes of `self`.
+        """
+        return self.forward_static(*args, **kwargs)
+
+    def forward_compile(self, *args, **kwargs):
+        """Runs a torch-compiled version of forward_native.
+
+        This method hooks into the forward_native method and compiles
+        `forward_static` using `torch.compile` if it has not been compiled yet.
+        """
+        if not self._is_compiled and not envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static)
+            self._is_compiled = True
+        return self.forward_native(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
-        raise NotImplementedError
+        """Forward method for NVIDIA GPUs.
+
+        By default, we use torch.compile to optimize the op. However, we can
+        override this method to use a custom CUDA implementation if needed.
+        """
+        return self.forward_compile(*args, **kwargs)

    def forward_hip(self, *args, **kwargs):
-        # By default, we assume that HIP ops are compatible with CUDA ops.
+        """Forward method for AMD GPUs.
+
+        By default, this method is the same as forward_cuda. However, we can
+        override this method to use a custom HIP implementation if needed.
+        """
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
-        raise NotImplementedError
+        """Forward method for Intel XPUs.
+
+        By default, we use torch.compile to optimize the op. However, we can
+        override this method to use a custom XPU implementation if needed.
+        """
+        return self.forward_compile(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
-        # By default, we assume that CPU ops are compatible with CUDA ops.
-        return self.forward_cuda(*args, **kwargs)
+        """Forward method for CPUs.

-    def forward_tpu(self, *args, **kwargs):
-        # By default, we assume that TPU ops are compatible with the
-        # PyTorch-native implementation.
-        # NOTE(woosuk): This is a placeholder for future extensions.
+        By default, we use torch.compile to optimize the op. However, we can
+        override this method to use a custom CPU implementation if needed.
+        """
+        if not self._is_compiled and not envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE:
+            self.forward_static = torch.compile(  # type: ignore
+                self.forward_static,
[Review comment] Contributor: Setting dynamic=True explicitly can reduce recompilations caused by the dynamic batch size. The CUDA path may benefit from the same treatment.

+                options={
+                    "fx_graph_cache": True,

[Review comment] Contributor (suggested change: drop `"fx_graph_cache": True,`): This option causes lock contention when using multiprocessing.

[Review comment] Contributor (@jon-chuang, Aug 8, 2024): Maybe you can set a per-process fx_graph_cache. You can set the env var TORCHINDUCTOR_CACHE_DIR. See: https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html

[Review comment] Collaborator (author): Removed. I think we can explore caching in a future PR.

+                    "cpp_wrapper": True,
+                    "dce": True,
+                })
+            self._is_compiled = True
        return self.forward_native(*args, **kwargs)

-    def forward_gaudi(self, *args, **kwargs):
-        # By default, we assume that Gaudi ops are compatible with the
-        # PyTorch-native implementation.
-        # NOTE(woosuk): This is a placeholder for future extensions.
+    def forward_tpu(self, *args, **kwargs):
+        """Forward method for TPUs.
+
+        For TPUs, the whole model is torch-compiled instead of individual ops.
+        So, we can just use the native implementation.
+        """
        return self.forward_native(*args, **kwargs)

+    def forward_hpu(self, *args, **kwargs):
+        """Forward method for HPUs."""
+        raise NotImplementedError("HPU is not supported yet.")
+
    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

(rest of the file unchanged)
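To make the review discussion above concrete, here is a minimal standalone sketch of the compile-on-first-call pattern this file introduces, with the two reviewer suggestions folded in: `dynamic=True` so a changing batch size does not force recompilation, and a per-process `TORCHINDUCTOR_CACHE_DIR` in place of the shared `fx_graph_cache` option. The class name and cache path are illustrative only and are not part of this PR.

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

# Reviewer suggestion: give each process its own Inductor cache directory
# instead of sharing the fx_graph_cache across workers (path is illustrative).
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR",
                      f"/tmp/inductor_cache_{os.getpid()}")


class LazyCompiledSiluAndMul(nn.Module):
    """Sketch of the compile-on-first-call pattern used by CustomOp."""

    def __init__(self):
        super().__init__()
        self._is_compiled = False

    @staticmethod
    def forward_static(x: torch.Tensor) -> torch.Tensor:
        # A pure function of its inputs: no reference to `self`, so
        # torch.compile does not have to guard on module attributes.
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self._is_compiled:
            # Reviewer suggestion: dynamic=True marks shapes as dynamic up
            # front, so a changing batch size does not force recompilation.
            self.forward_static = torch.compile(self.forward_static,
                                                dynamic=True)
            self._is_compiled = True
        return self.forward_static(x)


if __name__ == "__main__":
    op = LazyCompiledSiluAndMul()
    for batch_size in (1, 4, 16):
        print(op(torch.randn(batch_size, 256)).shape)  # compiled once

Assigning the compiled callable over the class-level staticmethod on first use mirrors what forward_compile does in the diff above; subsequent calls dispatch straight to the compiled function.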
vllm/model_executor/layers/activation.py (110 changes: 15 additions & 95 deletions)

@@ -23,29 +23,11 @@ class SiluAndMul(CustomOp):
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor) -> torch.Tensor:
[Review comment] Contributor (@bnellnm, Aug 5, 2024): Does this completely eliminate the custom silu_and_mul kernel? If so, should it be removed from csrc? Ditto for the rest of the custom activation ops.

[Review comment] Collaborator (author): That's a good question. I think we can delete most of them, while leaving some (e.g., in csrc/legacy) for potential future use?

        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
-        return out
-
-    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
-        return out


class GeluAndMul(CustomOp):
    """An activation function for GeGLU.
@@ -63,114 +45,52 @@ def __init__(self, approximate: str = "none"):
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor, approximate: str) -> torch.Tensor:
        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
-        return out
-
-    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        d = x.shape[-1] // 2
-        output_shape = (x.shape[:-1] + (d, ))
-        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
-        return out
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_static(x, self.approximate)

    def extra_repr(self) -> str:
        return f'approximate={repr(self.approximate)}'


class NewGELU(CustomOp):

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor) -> torch.Tensor:
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out
-
-    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        out = torch.empty_like(x)
-        ops.gelu_new(out, x)
-        return out


class FastGELU(CustomOp):

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out
-
-    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
-        return out


class QuickGELU(CustomOp):

    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
-        out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
-        return out
-
-    # TODO implement forward_xpu for QuickGELU
-    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:


class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation equivalent to forward()."""
+    @staticmethod
+    def forward_static(x: torch.Tensor) -> torch.Tensor:
        return torch.square(F.relu(x))

-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.
(rest of the file unchanged)
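A usage note on the activation refactor: `forward_static` now receives `approximate` as an explicit argument instead of reading `self.approximate`, so the function handed to `torch.compile` depends only on its inputs. A rough standalone equivalence check of that idea (illustrative code, not part of this PR or the vLLM test suite) could look like:

import torch
import torch.nn.functional as F


def gelu_and_mul(x: torch.Tensor, approximate: str) -> torch.Tensor:
    # Same math as GeluAndMul.forward_static in the diff above.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


compiled_gelu_and_mul = torch.compile(gelu_and_mul)

x = torch.randn(8, 128)
for approximate in ("none", "tanh"):
    eager = gelu_and_mul(x, approximate)
    compiled = compiled_gelu_and_mul(x, approximate)
    torch.testing.assert_close(compiled, eager)  # compiled matches eager

Because nothing closes over `self`, Dynamo specializes only on the string value of `approximate`, so each mode gets its own compiled graph behind the same wrapper.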