11# SPDX-License-Identifier: Apache-2.0
22
3- from typing import Optional
3+ from typing import Optional , List
44
55import torch
66
@@ -295,3 +295,125 @@ def copy_blocks(key_caches: list[torch.Tensor],
295295 def swap_blocks (src : torch .Tensor , dst : torch .Tensor ,
296296 block_mapping : torch .Tensor ) -> None :
297297 torch .xpu .swap_blocks (src , dst , block_mapping ) # type: ignore
298+
299+ @staticmethod
300+ def bgmv_shrink (inputs : torch .Tensor ,
301+ lora_a_weights : torch .Tensor ,
302+ output_tensor : torch .Tensor ,
303+ lora_indices_tensor : torch .Tensor ,
304+ scaling : float = 1.0 ) -> None :
305+ ipex .llm .functional .bgmv_shrink (inputs , lora_a_weights , output_tensor ,
306+ lora_indices_tensor , scaling )
307+
308+ @staticmethod
309+ def bgmv_expand (inputs : torch .Tensor ,
310+ lora_b_weights : torch .Tensor ,
311+ output_tensor : torch .Tensor ,
312+ lora_indices_tensor : torch .Tensor ,
313+ add_inputs : bool = True ) -> None :
314+ ipex .llm .functional .bgmv_expand (inputs , lora_b_weights , output_tensor ,
315+ lora_indices_tensor , add_inputs )
316+
317+ @staticmethod
318+ def bgmv_expand_slice (inputs : torch .Tensor ,
319+ lora_b_weights : torch .Tensor ,
320+ output_tensor : torch .Tensor ,
321+ lora_indices_tensor : torch .Tensor ,
322+ slice_offset : int ,
323+ slice_size : int ,
324+ add_inputs : bool = True ) -> None :
325+ ipex .llm .functional .bgmv_expand_slice (inputs , lora_b_weights ,
326+ output_tensor ,
327+ lora_indices_tensor ,
328+ slice_offset , slice_size ,
329+ add_inputs )
330+
331+ @staticmethod
332+ def sgmv_shrink (inputs : torch .Tensor ,
333+ lora_a_weights : torch .Tensor ,
334+ output_tensor : torch .Tensor ,
335+ b_seq_start_loc : torch .Tensor ,
336+ seq_len_tensor : torch .Tensor ,
337+ lora_indices_tensor : torch .Tensor ,
338+ batches : int ,
339+ max_seq_length : int ,
340+ token_nums : int ,
341+ scaling : float = 1.0 ) -> None :
342+ assert inputs .size (0 ) == token_nums
343+ ipex .llm .functional .sgmv_shrink (inputs , lora_a_weights , output_tensor ,
344+ b_seq_start_loc , seq_len_tensor ,
345+ lora_indices_tensor , batches ,
346+ max_seq_length , scaling )
347+
348+ @staticmethod
349+ def sgmv_expand (inputs : torch .Tensor ,
350+ lora_b_weights : torch .Tensor ,
351+ output_tensor : torch .Tensor ,
352+ b_seq_start_loc : torch .Tensor ,
353+ seq_len_tensor : torch .Tensor ,
354+ lora_indices_tensor : torch .Tensor ,
355+ batches : int ,
356+ max_seq_length : int ,
357+ token_nums : int ,
358+ add_inputs : bool = False ) -> None :
359+ assert inputs .size (0 ) == token_nums
360+ ipex .llm .functional .sgmv_expand (inputs , lora_b_weights , output_tensor ,
361+ b_seq_start_loc , seq_len_tensor ,
362+ lora_indices_tensor , batches ,
363+ max_seq_length , add_inputs )
364+
365+ @staticmethod
366+ def sgmv_expand_slice (inputs : torch .Tensor ,
367+ lora_b_weights : torch .Tensor ,
368+ output_tensor : torch .Tensor ,
369+ b_seq_start_loc : torch .Tensor ,
370+ seq_len_tensor : torch .Tensor ,
371+ lora_indices_tensor : torch .Tensor ,
372+ batches : int ,
373+ max_seq_length : int ,
374+ token_nums : int ,
375+ slice_offset : int ,
376+ slice_size : int ,
377+ add_inputs : bool = False ) -> None :
378+ assert inputs .size (0 ) == token_nums
379+ ipex .llm .functional .sgmv_expand_slice (inputs , lora_b_weights ,
380+ output_tensor , b_seq_start_loc ,
381+ seq_len_tensor ,
382+ lora_indices_tensor , batches ,
383+ max_seq_length , slice_offset ,
384+ slice_size , add_inputs )
385+
386+ # @staticmethod
387+ # def lora_expand(inputs: torch.Tensor,
388+ # lora_b_weights: List[torch.Tensor],
389+ # output_tensor: torch.Tensor,
390+ # token_lora_mapping: torch.Tensor,
391+ # token_indices_sorted_by_lora_ids: torch.Tensor,
392+ # num_tokens_per_lora: torch.Tensor,
393+ # lora_token_start_loc: torch.Tensor,
394+ # lora_ids: torch.Tensor,
395+ # offset_start: int = 0,
396+ # add_inputs: bool = False) -> None:
397+ # ipex.llm.functional.lora_expand(inputs, lora_b_weights,
398+ # output_tensor, token_lora_mapping,
399+ # token_indices_sorted_by_lora_ids,
400+ #                                     num_tokens_per_lora,
401+ # lora_token_start_loc, lora_ids,
402+ # offset_start, add_inputs)
403+
404+ # @staticmethod
405+ # def lora_shrink(inputs: torch.Tensor,
406+ # lora_a_weights: List[torch.Tensor],
407+ # output_tensor: torch.Tensor,
408+ # token_lora_mapping: torch.Tensor,
409+ # token_indices_sorted_by_lora_ids: torch.Tensor,
410+ # num_tokens_per_lora: torch.Tensor,
411+ # lora_token_start_loc: torch.Tensor,
412+ # lora_ids: torch.Tensor,
413+ # scaling: float) -> None:
414+ # ipex.llm.functional.lora_shrink(inputs, lora_a_weights,
415+ # output_tensor, token_lora_mapping,
416+ # token_indices_sorted_by_lora_ids,
417+ #                                     num_tokens_per_lora,
418+ # lora_token_start_loc, lora_ids,
419+ # scaling)
0 commit comments