@@ -30,6 +30,8 @@ class FatreluAndMul(CustomOp):
     def __init__(self, threshold: float = 0.):
         super().__init__()
         self.threshold = threshold
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.fatrelu_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -39,12 +41,10 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x1 * x2
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.fatrelu_and_mul(out, x, self.threshold)
+        self.op(out, x, self.threshold)
         return out
 
 
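For reference, the fused kernel gates the first half of the last dimension with FATReLU (pass values above the threshold, zero the rest) and multiplies by the second half. A pure-PyTorch sketch of the same computation (the reference function name is illustrative, not part of this PR):

import torch
import torch.nn.functional as F

def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float) -> torch.Tensor:
    # Split the last dimension in half: gate half and value half.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # FATReLU: keep elements above the threshold, zero everything else.
    x1 = F.threshold(x1, threshold, 0.0)
    return x1 * x2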
@@ -103,34 +103,35 @@ def __init__(self, approximate: str = "none"):
         self.approximate = approximate
         if approximate not in ("none", "tanh"):
             raise ValueError(f"Unknown approximate mode: {approximate}")
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            if approximate == "none":
+                self.op = torch.ops._C.gelu_and_mul
+            elif approximate == "tanh":
+                self.op = torch.ops._C.gelu_tanh_and_mul
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            if approximate == "none":
+                self.op = ipex_ops.gelu_and_mul
+            else:
+                self.op = ipex_ops.gelu_tanh_and_mul
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        if self.approximate == "none":
-            ops.gelu_and_mul(out, x)
-        elif self.approximate == "tanh":
-            ops.gelu_tanh_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def extra_repr(self) -> str:
@@ -140,65 +141,77 @@ def extra_repr(self) -> str:
 @CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_new
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_new
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_new(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_new(x)
+        return self.op(x)
 
 
 @CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_fast
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_fast
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_fast(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
-        return ops.gelu_fast(x)
+        return self.op(x)
 
 
 @CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.gelu_quick
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.gelu_quick
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         return x * torch.sigmoid(1.702 * x)
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         out = torch.empty_like(x)
-        ops.gelu_quick(out, x)
+        self.op(out, x)
        return out
 
     # TODO implement forward_xpu for QuickGELU
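Taken together, the diff replaces the per-call `from vllm import _custom_ops as ops` import and `approximate`-string branching with a kernel handle resolved once at construction time, so each `forward_*` reduces to a single call on `self.op`. A minimal self-contained sketch of that pattern (hypothetical module, shown only to illustrate the dispatch; `torch.ops._C.gelu_quick` assumes vllm's compiled extension is loaded):

import torch

class QuickGELUSketch(torch.nn.Module):
    """Illustrative only: bind the kernel once, call the handle later."""

    def __init__(self) -> None:
        super().__init__()
        # Resolve the platform-specific kernel at construction time;
        # fall back to None when no custom kernel is available.
        if torch.cuda.is_available():
            self.op = torch.ops._C.gelu_quick  # assumes vllm's _C extension
        else:
            self.op = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.op is None:
            # Native QuickGELU: x * sigmoid(1.702 * x)
            return x * torch.sigmoid(1.702 * x)
        out = torch.empty_like(x)
        self.op(out, x)  # the custom kernel writes its result into `out`
        return out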