diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 41ab1e681e24..e23f1b4b6d56 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -44,6 +44,7 @@ from ...utils.distributed import distributed_gather from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME from ...utils.log import logger +from ...utils.tools import get_env_device from .lora_config import LoRAConfig try: @@ -51,20 +52,70 @@ ColumnSequenceParallelLinear, RowSequenceParallelLinear, ) - - from .lora_layers import ( - ColumnParallelLoRALinear, - ColumnParallelLoRAMergedLinear, - ColumnSequenceParallelLoRALinear, - LoRAConv2D, - LoRALinear, - LoRAMergedLinear, - RowParallelLoRALinear, - RowSequenceParallelLoRALinear, - ) except: pass + +def get_lora_layers(): + try: + if get_env_device() == "xpu": + # If paddle_xpu is not installed, just use PaddleNLP's native lora layers + from paddle_xpu.layers.nn.lora_layers import ( + XPUColumnParallelLoRALinear as ColumnParallelLoRALinear, + ) + from paddle_xpu.layers.nn.lora_layers import ( + XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear, + ) + from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear + from paddle_xpu.layers.nn.lora_layers import ( + XPURowParallelLoRALinear as RowParallelLoRALinear, + ) + from paddle_xpu.layers.nn.lora_layers import ( + XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear, + ) + + from .lora_layers import ( + ColumnParallelLoRAMergedLinear, + LoRAConv2D, + LoRAMergedLinear, + ) + + else: + raise ImportError # Force to use the fallback if not XPU + except ImportError: + from .lora_layers import ( + ColumnParallelLoRALinear, + ColumnParallelLoRAMergedLinear, + ColumnSequenceParallelLoRALinear, + LoRAConv2D, + LoRALinear, + LoRAMergedLinear, + RowParallelLoRALinear, + RowSequenceParallelLoRALinear, + ) + + return { + "ColumnParallelLoRALinear": ColumnParallelLoRALinear, + "ColumnParallelLoRAMergedLinear": ColumnParallelLoRAMergedLinear, + "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear, + "LoRAConv2D": LoRAConv2D, + "LoRALinear": LoRALinear, + "LoRAMergedLinear": LoRAMergedLinear, + "RowParallelLoRALinear": RowParallelLoRALinear, + "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear, + } + + +lora_layers = get_lora_layers() +ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"] +ColumnParallelLoRAMergedLinear = lora_layers["ColumnParallelLoRAMergedLinear"] +ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"] +LoRAConv2D = lora_layers["LoRAConv2D"] +LoRALinear = lora_layers["LoRALinear"] +LoRAMergedLinear = lora_layers["LoRAMergedLinear"] +RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"] +RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"] + try: from ...quantization.quantization_linear import ( ColumnParallelQuantizationLinear,