diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index a624c4ca9ee62..a4b9f91c7688b 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -44,7 +44,7 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
     # implementations, so we can do exact comparison.
     assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
@@ -72,7 +72,7 @@ def test_activation(
     x = torch.randn(num_tokens, d, dtype=dtype)
     layer = activation()
     out = layer(x)
-    ref_out = layer._forward(x)
+    ref_out = layer.forward_native(x)
     assert torch.allclose(out,
                           ref_out,
                           atol=get_default_atol(out),
diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py
index 210d59e4f32fa..a635e6c12c594 100644
--- a/tests/kernels/test_layernorm.py
+++ b/tests/kernels/test_layernorm.py
@@ -42,7 +42,7 @@ def test_rms_norm(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_out = layer._forward(x, residual)
+    ref_out = layer.forward_native(x, residual)
     out = layer(x, residual)
     # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
     # numerical errors than other operators because they involve reductions.
diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py
index fbabc02bf9a9d..e564e325112a6 100644
--- a/tests/kernels/test_pos_encoding.py
+++ b/tests/kernels/test_pos_encoding.py
@@ -64,7 +64,7 @@ def test_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
     assert torch.allclose(out_query,
@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key)
+    ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions,
                                       query,
                                       key,
@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(

     # NOTE(woosuk): The reference implementation should be executed first
     # because the custom kernel is in-place.
-    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
+    ref_query, ref_key = rope.forward_native(positions, query, key,
+                                             query_offsets)
     out_query, out_key = rope.forward(positions, query, key,
                                       query_offsets.flatten())
     # Compare the results.
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
new file mode 100644
index 0000000000000..1d49213cd4ab5
--- /dev/null
+++ b/vllm/model_executor/custom_op.py
@@ -0,0 +1,60 @@
+import torch.nn as nn
+
+from vllm.utils import is_cpu, is_hip
+
+
+class CustomOp(nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        """PyTorch-native implementation of the forward method.
+
+        This method is optional. If implemented, it can be used with compilers
+        such as torch.compile or PyTorch XLA. Also, it can be used for testing
+        purposes.
+        """
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        # By default, we assume that HIP ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_xpu(self, *args, **kwargs):
+        # By default, we assume that XPU ops are compatible with CUDA ops.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        # By default, we assume that CPU ops are compatible with CUDA ops.
+        return self.forward_cuda(*args, **kwargs)
+
+    def forward_tpu(self, *args, **kwargs):
+        # By default, we assume that TPU ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def forward_gaudi(self, *args, **kwargs):
+        # By default, we assume that Gaudi ops are compatible with the
+        # PyTorch-native implementation.
+        # NOTE(woosuk): This is a placeholder for future extensions.
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        # NOTE(woosuk): Here we assume that vLLM was built for only one
+        # specific backend. Currently, we do not support dynamic dispatching.
+        if is_hip():
+            return self.forward_hip
+        elif is_cpu():
+            return self.forward_cpu
+        else:
+            return self.forward_cuda
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index d101aa323b0e1..4d076421f9d2a 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -6,14 +6,14 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from vllm import _custom_ops as ops
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs


-class SiluAndMul(nn.Module):
+class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.

     The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -36,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return out


-class GeluAndMul(nn.Module):
+class GeluAndMul(CustomOp):
     """An activation function for GeGLU.

     The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
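A minimal sketch, not part of the patch, of how the CustomOp base class above routes calls: forward() is bound at construction time to one backend-specific method by dispatch_forward(), while forward_native() remains available as an explicit reference path, which is what the updated tests call. The ReluSquared op is hypothetical and exists only to illustrate the dispatch; a real subclass such as SiluAndMul calls a fused kernel from vllm._custom_ops in forward_cuda().

import torch
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp


class ReluSquared(CustomOp):
    """Hypothetical op: relu(x) ** 2, used only to illustrate dispatch."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch reference; usable under torch.compile / XLA and in tests.
        return F.relu(x) ** 2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused kernel here; the sketch just reuses
        # the native path so it runs on any backend.
        return self.forward_native(x)


layer = ReluSquared()
x = torch.randn(8, 16)
out = layer(x)                 # CustomOp.forward -> self._forward_method
ref = layer.forward_native(x)  # explicit reference path, as in the tests above
assert torch.allclose(out, ref)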
@@ -52,12 +54,14 @@ def __init__(self, approximate: str = "none"):
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -71,28 +75,32 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'


-class NewGELU(nn.Module):
+class NewGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
         return out


-class FastGELU(nn.Module):
+class FastGELU(CustomOp):

-    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
         return out
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index 8de0794158986..4533adf8f83aa 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -4,10 +4,10 @@
 import torch
 import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


-class RMSNorm(nn.Module):
+class RMSNorm(CustomOp):
     """Root mean square normalization.

     Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
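As a reading aid for the docstring above, a standalone sketch of the RMS normalization formula w * x / sqrt(E[x^2] + eps). This is an illustration only, not the layer's code: it ignores the optional fused residual path visible in the hunks below and any dtype handling the layer may perform.

import torch


def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       eps: float = 1e-6) -> torch.Tensor:
    # E[x^2] over the hidden dimension, then rescale by the learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


# Example: normalize a (num_tokens, hidden_size) activation.
y = rms_norm_reference(torch.randn(4, 16), torch.ones(16))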
@@ -23,7 +23,7 @@ def __init__(
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

-    def _forward(
+    def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
@@ -43,11 +43,13 @@ def _forward(
         else:
             return x, residual

-    def forward(
+    def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from vllm import _custom_ops as ops
+
         if residual is not None:
             ops.fused_add_rms_norm(
                 x,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index d03903d206d33..d2652106b8441 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -27,7 +27,7 @@
 import torch
 import torch.nn as nn

-from vllm import _custom_ops as ops
+from vllm.model_executor.custom_op import CustomOp


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
     return x.flatten(-2)


-class RotaryEmbedding(nn.Module):
+class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""

     def __init__(
@@ -93,7 +93,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
         cache = torch.cat((cos, sin), dim=-1)
         return cache

-    def _forward(
+    def forward_native(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
@@ -138,13 +138,15 @@ def _forward(
         key = key.flatten(-2)
         return query, key

-    def forward(
+    def forward_cuda(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        from vllm import _custom_ops as ops
+
         self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
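For orientation only, a self-contained sketch of the rotation that rotary embeddings apply, written against the NeoX-style layout suggested by _rotate_neox above. This is not the code from rotary_embedding.py: it assumes cos and sin have already been gathered per position and expanded to the full rotary dimension, and it ignores offsets, partial rotary dimensions, and head reshaping. The fused kernel invoked by forward_cuda applies the same rotation in place, which is why the updated tests run forward_native() before the custom path.

import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor,
                         sin: torch.Tensor) -> torch.Tensor:
    # Classic rotary embedding: x * cos + rotate(x) * sin.
    return x * cos + rotate_neox(x) * sin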