@@ -27,29 +27,27 @@ def _reshape_activation_tensor(
 
     @staticmethod
     def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.silu_mul(x1, x2, out)
+        ipex.llm.functional.silu_and_mul(x, out)
 
     @staticmethod
     def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "none")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
     def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-        x1, x2 = ipex_ops._reshape_activation_tensor(x)
-        ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
+        ipex.llm.functional.gelu_and_mul(x, out)
 
     @staticmethod
-    def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_fast(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
     @staticmethod
-    def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
-        out.copy_(torch.nn.functional.gelu(x))
+    def gelu_new(x: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.gelu(x)
 
-    # TODO add implementation of gelu_quick here
-    # def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+    @staticmethod
+    def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
+        ipex.llm.functional.gelu_quick(x, out)
 
     @staticmethod
     def paged_attention_v1(
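For reference, the fused activation ops consume an input whose last dimension holds the gate and up projections back to back and write the activated product into `out`. A minimal PyTorch sketch of the expected `silu_and_mul` semantics, assuming the usual gate/up split (shapes below are illustrative, not taken from this diff):

```python
import torch
import torch.nn.functional as F


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim in half, apply SiLU to the first half,
    multiply elementwise by the second half."""
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(4, 2 * 256)        # 4 tokens, gate + up concatenated
out = torch.empty(4, 256)
out.copy_(silu_and_mul_ref(x))     # what the fused kernel is expected to write
```

`gelu_and_mul` follows the same pattern with GELU in place of SiLU.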
@@ -160,67 +158,26 @@ def rotary_embedding(
         cos_sin_cache: torch.Tensor,  # [cos_sin_dim, rot_dim]
         is_neox: bool,
     ) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[positions.long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        rot_dim = cos_sin_cache.size(1)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim)
 
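The deleted Python path spelled out the cos/sin handling that `rotary_embedding_batched` now performs internally. A condensed sketch of that removed reference behaviour, kept here for context (shapes as in the old code: `positions` is [batch, seq] and each cache row holds the cos and sin halves):

```python
import torch


def expand_cos_sin(cos_sin_cache: torch.Tensor, positions: torch.Tensor,
                   is_neox: bool):
    """Mirror of the removed path: look up cos/sin per position and expand
    them to the layout the rotary kernel consumed."""
    cos, sin = cos_sin_cache[positions.long()].chunk(2, dim=-1)
    if is_neox:   # NeoX style: rotate [first half | second half]
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:         # GPT-J style: rotate interleaved even/odd pairs
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
    return cos, sin
```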
     @staticmethod
     def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
                                  key: torch.Tensor, head_size: int,
                                  cos_sin_cache: torch.Tensor, is_neox: bool,
                                  rot_dim: int,
                                  cos_sin_cache_offsets: torch.Tensor) -> None:
-        if positions.dim() == 1:
-            positions = positions.unsqueeze(0)
-            query = query.unsqueeze(0)
-            key = key.unsqueeze(0)
-            cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
-        rotary_dim = cos_sin_cache.size(1)
-        query = query.view(*query.shape[:-1], -1, head_size)
-        key = key.view(*key.shape[:-1], -1, head_size)
-
-        query_rot = query[..., :rotary_dim]
-        key_rot = key[..., :rotary_dim]
-
-        cos_sin = cos_sin_cache[torch.add(positions,
-                                          cos_sin_cache_offsets).long()]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-
-        if is_neox:
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
-                                             rotary_dim, is_neox, positions)
+        ipex.llm.functional.rotary_embedding_batched(positions, query, key,
+                                                     head_size, cos_sin_cache,
+                                                     is_neox, rot_dim,
+                                                     cos_sin_cache_offsets)
 
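The batched variant differs from the path above only in how the cache row is chosen: the per-token `cos_sin_cache_offsets` are added to `positions` before indexing, as the removed lookup did. A small self-contained sketch of that indexing (tensor values below are illustrative):

```python
import torch


def lookup_with_offsets(cos_sin_cache: torch.Tensor, positions: torch.Tensor,
                        cos_sin_cache_offsets: torch.Tensor) -> torch.Tensor:
    """Cache row = position + per-token offset, as in the removed code."""
    return cos_sin_cache[torch.add(positions, cos_sin_cache_offsets).long()]


cache = torch.randn(8, 64)        # e.g. two scaling regions of 4 rows each
rows = lookup_with_offsets(cache, torch.tensor([0, 1, 2]),
                           torch.tensor([0, 4, 4]))
```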
     @staticmethod
-    def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
-                 epsilon: float) -> None:
-        tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
-        out.copy_(tmp)
+    def rms_norm(input: torch.Tensor, weight: torch.Tensor,
+                 epsilon: float) -> torch.Tensor:
+        return ipex.llm.functional.rms_norm(input, weight, epsilon)
 
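For comparison, a plain PyTorch sketch of the RMSNorm value the fused `ipex.llm.functional.rms_norm` call is expected to return, assuming the standard definition (normalize by the root mean square over the hidden dimension, then scale by a learned per-channel weight):

```python
import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 epsilon: float) -> torch.Tensor:
    """x / sqrt(mean(x^2) + eps), scaled per channel by weight."""
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + epsilon)).to(x.dtype) * weight


hidden = torch.randn(4, 4096, dtype=torch.float16)
out = rms_norm_ref(hidden, torch.ones(4096, dtype=torch.float16), 1e-6)
```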
     @staticmethod
     def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
@@ -246,11 +203,14 @@ def varlen_attention(
         return_softmax: bool,
         gen_: torch.Generator,
     ) -> None:
-        ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q,
-                                             seqlen_k, max_seqlen_q,
-                                             max_seqlen_k, pdropout,
-                                             softmax_scale, zero_tensors,
-                                             is_causal, return_softmax, gen_)
+        ipex.llm.functional.varlen_attention(query.contiguous(),
+                                             key.contiguous(),
+                                             value.contiguous(), out,
+                                             seqlen_q.int(), seqlen_k.int(),
+                                             max_seqlen_q, max_seqlen_k,
+                                             pdropout, softmax_scale,
+                                             zero_tensors, is_causal,
+                                             return_softmax, gen_)
 
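Note the added `.contiguous()` calls and the `.int()` casts on the sequence-length tensors before handing them to the kernel. Callers of varlen attention typically pass cumulative sequence-length offsets for the packed query/key layout (flash-attention style); a sketch of building such a tensor under that assumption, not taken from this diff:

```python
import torch


def make_cu_seqlens(seq_lens: list[int]) -> torch.Tensor:
    """Cumulative offsets [0, l0, l0+l1, ...] as int32 for packed varlen batches."""
    return torch.cumsum(torch.tensor([0] + seq_lens), dim=0).to(torch.int32)


seqlen_q = make_cu_seqlens([5, 3, 7])   # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```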
     @staticmethod
     def reshape_and_cache(