
Commit c6ea654

Enable multi-LoRA (vllm-project#147)
* enable lora on xpu
* adjust python api
* fix lora

Signed-off-by: yan ma <yan.ma@intel.com>

* Porting multi-LoRA to new code base

---------

Signed-off-by: yan ma <yan.ma@intel.com>
Co-authored-by: fengqing.lu <fengqing.lu@intel.com>
Co-authored-by: yan ma <yan.ma@intel.com>
1 parent b3f7431 commit c6ea654

File tree: 7 files changed, +1381 -69 lines


examples/offline_inference/multilora_inference.py

Lines changed: 8 additions & 8 deletions
@@ -98,14 +98,14 @@ def initialize_engine() -> LLMEngine:
     # numbers will cause higher memory usage. If you know that all LoRAs will
     # use the same rank, it is recommended to set this as low as possible.
     # max_cpu_loras: controls the size of the CPU LoRA cache.
-    engine_args = EngineArgs(
-        model="meta-llama/Llama-2-7b-hf",
-        enable_lora=True,
-        max_loras=1,
-        max_lora_rank=8,
-        max_cpu_loras=2,
-        max_num_seqs=256,
-    )
+    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
+                             enable_lora=True,
+                             max_loras=1,
+                             max_lora_rank=8,
+                             max_cpu_loras=2,
+                             max_num_seqs=256,
+                             enforce_eager=True,
+                             block_size=64)
     return LLMEngine.from_engine_args(engine_args)
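The example now builds EngineArgs in a single call and adds enforce_eager=True and block_size=64, which this port appears to require (or at least prefer) on XPU. For context, here is a minimal sketch of how an engine configured this way serves a request with a specific adapter, following the pattern used elsewhere in this example file; the adapter name, integer ID, and path below are placeholders, not part of the commit:

    from vllm import SamplingParams
    from vllm.lora.request import LoRARequest

    engine = initialize_engine()  # the function patched above
    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # Placeholder adapter: the name, id, and local path are illustrative only.
    lora_request = LoRARequest("my_adapter", 1, "/path/to/lora_adapter")

    engine.add_request("req-0",
                       "Give three tips for staying healthy.",
                       sampling_params,
                       lora_request=lora_request)

    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)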
vllm/_ipex_ops.py

Lines changed: 123 additions & 1 deletion
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Optional
+from typing import Optional, List

 import torch

@@ -295,3 +295,125 @@ def copy_blocks(key_caches: list[torch.Tensor],
     def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
                     block_mapping: torch.Tensor) -> None:
         torch.xpu.swap_blocks(src, dst, block_mapping)  # type: ignore
+
+    @staticmethod
+    def bgmv_shrink(inputs: torch.Tensor,
+                    lora_a_weights: torch.Tensor,
+                    output_tensor: torch.Tensor,
+                    lora_indices_tensor: torch.Tensor,
+                    scaling: float = 1.0) -> None:
+        ipex.llm.functional.bgmv_shrink(inputs, lora_a_weights, output_tensor,
+                                        lora_indices_tensor, scaling)
+
+    @staticmethod
+    def bgmv_expand(inputs: torch.Tensor,
+                    lora_b_weights: torch.Tensor,
+                    output_tensor: torch.Tensor,
+                    lora_indices_tensor: torch.Tensor,
+                    add_inputs: bool = True) -> None:
+        ipex.llm.functional.bgmv_expand(inputs, lora_b_weights, output_tensor,
+                                        lora_indices_tensor, add_inputs)
+
+    @staticmethod
+    def bgmv_expand_slice(inputs: torch.Tensor,
+                          lora_b_weights: torch.Tensor,
+                          output_tensor: torch.Tensor,
+                          lora_indices_tensor: torch.Tensor,
+                          slice_offset: int,
+                          slice_size: int,
+                          add_inputs: bool = True) -> None:
+        ipex.llm.functional.bgmv_expand_slice(inputs, lora_b_weights,
+                                              output_tensor,
+                                              lora_indices_tensor,
+                                              slice_offset, slice_size,
+                                              add_inputs)
+
+    @staticmethod
+    def sgmv_shrink(inputs: torch.Tensor,
+                    lora_a_weights: torch.Tensor,
+                    output_tensor: torch.Tensor,
+                    b_seq_start_loc: torch.Tensor,
+                    seq_len_tensor: torch.Tensor,
+                    lora_indices_tensor: torch.Tensor,
+                    batches: int,
+                    max_seq_length: int,
+                    token_nums: int,
+                    scaling: float = 1.0) -> None:
+        assert inputs.size(0) == token_nums
+        ipex.llm.functional.sgmv_shrink(inputs, lora_a_weights, output_tensor,
+                                        b_seq_start_loc, seq_len_tensor,
+                                        lora_indices_tensor, batches,
+                                        max_seq_length, scaling)
+
+    @staticmethod
+    def sgmv_expand(inputs: torch.Tensor,
+                    lora_b_weights: torch.Tensor,
+                    output_tensor: torch.Tensor,
+                    b_seq_start_loc: torch.Tensor,
+                    seq_len_tensor: torch.Tensor,
+                    lora_indices_tensor: torch.Tensor,
+                    batches: int,
+                    max_seq_length: int,
+                    token_nums: int,
+                    add_inputs: bool = False) -> None:
+        assert inputs.size(0) == token_nums
+        ipex.llm.functional.sgmv_expand(inputs, lora_b_weights, output_tensor,
+                                        b_seq_start_loc, seq_len_tensor,
+                                        lora_indices_tensor, batches,
+                                        max_seq_length, add_inputs)
+
+    @staticmethod
+    def sgmv_expand_slice(inputs: torch.Tensor,
+                          lora_b_weights: torch.Tensor,
+                          output_tensor: torch.Tensor,
+                          b_seq_start_loc: torch.Tensor,
+                          seq_len_tensor: torch.Tensor,
+                          lora_indices_tensor: torch.Tensor,
+                          batches: int,
+                          max_seq_length: int,
+                          token_nums: int,
+                          slice_offset: int,
+                          slice_size: int,
+                          add_inputs: bool = False) -> None:
+        assert inputs.size(0) == token_nums
+        ipex.llm.functional.sgmv_expand_slice(inputs, lora_b_weights,
+                                              output_tensor, b_seq_start_loc,
+                                              seq_len_tensor,
+                                              lora_indices_tensor, batches,
+                                              max_seq_length, slice_offset,
+                                              slice_size, add_inputs)
+
+    # @staticmethod
+    # def lora_expand(inputs: torch.Tensor,
+    #                 lora_b_weights: List[torch.Tensor],
+    #                 output_tensor: torch.Tensor,
+    #                 token_lora_mapping: torch.Tensor,
+    #                 token_indices_sorted_by_lora_ids: torch.Tensor,
+    #                 num_tokens_per_lora: torch.Tensor,
+    #                 lora_token_start_loc: torch.Tensor,
+    #                 lora_ids: torch.Tensor,
+    #                 offset_start: int = 0,
+    #                 add_inputs: bool = False) -> None:
+    #     ipex.llm.functional.lora_expand(inputs, lora_b_weights,
+    #                                     output_tensor, token_lora_mapping,
+    #                                     token_indices_sorted_by_lora_ids,
+    #                                     num_tokens_per_lora, num_tokens_per_lora,
+    #                                     lora_token_start_loc, lora_ids,
+    #                                     offset_start, add_inputs)
+
+    # @staticmethod
+    # def lora_shrink(inputs: torch.Tensor,
+    #                 lora_a_weights: List[torch.Tensor],
+    #                 output_tensor: torch.Tensor,
+    #                 token_lora_mapping: torch.Tensor,
+    #                 token_indices_sorted_by_lora_ids: torch.Tensor,
+    #                 num_tokens_per_lora: torch.Tensor,
+    #                 lora_token_start_loc: torch.Tensor,
+    #                 lora_ids: torch.Tensor,
+    #                 scaling: float) -> None:
+    #     ipex.llm.functional.lora_shrink(inputs, lora_a_weights,
+    #                                     output_tensor, token_lora_mapping,
+    #                                     token_indices_sorted_by_lora_ids,
+    #                                     num_tokens_per_lora, num_tokens_per_lora,
+    #                                     lora_token_start_loc, lora_ids,
+    #                                     scaling)
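The new static methods route vLLM's Punica-style LoRA kernels to their ipex.llm.functional counterparts on XPU: the bgmv_* variants take a per-token LoRA index tensor (decode-style batches), while the sgmv_* variants additionally take sequence start/length tensors, batch count, and max sequence length (prefill-style batches). As a rough eager-mode reference for what the bgmv shrink/expand pair computes, here is a plain-PyTorch sketch inferred only from the signatures above and the standard LoRA formulation (output += scaling * x @ A^T @ B^T); the actual IPEX kernels' weight layouts and transposition conventions may differ:

    import torch

    def bgmv_shrink_ref(inputs: torch.Tensor,
                        lora_a_weights: torch.Tensor,
                        output_tensor: torch.Tensor,
                        lora_indices_tensor: torch.Tensor,
                        scaling: float = 1.0) -> None:
        # Assumed layouts (for illustration only):
        #   inputs:              [num_tokens, hidden_size]
        #   lora_a_weights:      [num_loras, rank, hidden_size]
        #   output_tensor:       [num_tokens, rank]
        #   lora_indices_tensor: [num_tokens], LoRA id assigned to each token
        a = lora_a_weights[lora_indices_tensor]  # [num_tokens, rank, hidden_size]
        output_tensor[:] = scaling * torch.einsum("th,trh->tr", inputs, a)

    def bgmv_expand_ref(inputs: torch.Tensor,
                        lora_b_weights: torch.Tensor,
                        output_tensor: torch.Tensor,
                        lora_indices_tensor: torch.Tensor,
                        add_inputs: bool = True) -> None:
        # Assumed layouts (for illustration only):
        #   inputs:              [num_tokens, rank]  (the shrink output)
        #   lora_b_weights:      [num_loras, hidden_size, rank]
        #   output_tensor:       [num_tokens, hidden_size]  (base output to update)
        b = lora_b_weights[lora_indices_tensor]  # [num_tokens, hidden_size, rank]
        delta = torch.einsum("tr,thr->th", inputs, b)
        if add_inputs:
            output_tensor += delta  # accumulate the LoRA delta onto the base output
        else:
            output_tensor[:] = delta

The sgmv_* wrappers presumably compute the same shrink/expand per token but tile over variable-length sequences, which is why they also carry b_seq_start_loc, seq_len_tensor, batches, and max_seq_length. The lora_expand/lora_shrink wrappers remain commented out in this commit.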
