Skip to content

Commit

Permalink
[Misc] Add CustomOp interface for device portability (#5255)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored Jun 5, 2024
1 parent 974fc9b commit 41ca62c
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 27 deletions.
4 changes: 2 additions & 2 deletions tests/kernels/test_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_act_and_mul(
elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh")
out = layer(x)
ref_out = layer._forward(x)
ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_activation(
x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation()
out = layer(x)
ref_out = layer._forward(x)
ref_out = layer.forward_native(x)
assert torch.allclose(out,
ref_out,
atol=get_default_atol(out),
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/test_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_rms_norm(

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_out = layer._forward(x, residual)
ref_out = layer.forward_native(x, residual)
out = layer(x, residual)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions.
Expand Down
7 changes: 4 additions & 3 deletions tests/kernels/test_pos_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_rotary_embedding(

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
# Compare the results.
assert torch.allclose(out_query,
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_batched_rotary_embedding(

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions,
query,
key,
Expand Down Expand Up @@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(

# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten())
# Compare the results.
Expand Down
60 changes: 60 additions & 0 deletions vllm/model_executor/custom_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch.nn as nn

from vllm.utils import is_cpu, is_hip


class CustomOp(nn.Module):

def __init__(self, *args, **kwargs):
super().__init__()
self._forward_method = self.dispatch_forward()

def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)

def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise NotImplementedError

def forward_cuda(self, *args, **kwargs):
raise NotImplementedError

def forward_hip(self, *args, **kwargs):
# By default, we assume that HIP ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)

def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with CUDA ops.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)

def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)

def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)

def forward_gaudi(self, *args, **kwargs):
# By default, we assume that Gaudi ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)

def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if is_hip():
return self.forward_hip
elif is_cpu():
return self.forward_cpu
else:
return self.forward_cuda
34 changes: 21 additions & 13 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import torch.nn as nn
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs


class SiluAndMul(nn.Module):
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Expand All @@ -23,20 +23,22 @@ class SiluAndMul(nn.Module):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""

def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out


class GeluAndMul(nn.Module):
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Expand All @@ -52,12 +54,14 @@ def __init__(self, approximate: str = "none"):
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")

def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
Expand All @@ -71,28 +75,32 @@ def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'


class NewGELU(nn.Module):
class NewGELU(CustomOp):

def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_new(out, x)
return out


class FastGELU(nn.Module):
class FastGELU(CustomOp):

def _forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out
Expand Down
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import torch
import torch.nn as nn

from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp


class RMSNorm(nn.Module):
class RMSNorm(CustomOp):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Expand All @@ -23,7 +23,7 @@ def __init__(
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def _forward(
def forward_native(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
Expand All @@ -43,11 +43,13 @@ def _forward(
else:
return x, residual

def forward(
def forward_cuda(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm import _custom_ops as ops

if residual is not None:
ops.fused_add_rms_norm(
x,
Expand Down
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
import torch.nn as nn

from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
Expand All @@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)


class RotaryEmbedding(nn.Module):
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""

def __init__(
Expand Down Expand Up @@ -93,7 +93,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def _forward(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
Expand Down Expand Up @@ -138,13 +138,15 @@ def _forward(
key = key.flatten(-2)
return query, key

def forward(
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
Expand Down

0 comments on commit 41ca62c

Please sign in to comment.