[XPU] add lora optimization (#8527)
* [XPU] add lora optimization

* fix

* refine
dynamicheart committed Jul 2, 2024
1 parent 2723138 commit a53477c
Showing 1 changed file with 62 additions and 11 deletions.
paddlenlp/peft/lora/lora_model.py (62 additions, 11 deletions)
@@ -44,27 +44,78 @@
from ...utils.distributed import distributed_gather
from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
from ...utils.log import logger
from ...utils.tools import get_env_device
from .lora_config import LoRAConfig

try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        ColumnSequenceParallelLinear,
        RowSequenceParallelLinear,
    )

    from .lora_layers import (
        ColumnParallelLoRALinear,
        ColumnParallelLoRAMergedLinear,
        ColumnSequenceParallelLoRALinear,
        LoRAConv2D,
        LoRALinear,
        LoRAMergedLinear,
        RowParallelLoRALinear,
        RowSequenceParallelLoRALinear,
    )
except:
    pass


def get_lora_layers():
    try:
        if get_env_device() == "xpu":
            # paddle_xpu ships XPU-optimized LoRA layers; if it is not installed, the
            # ImportError is caught below and PaddleNLP's native lora layers are used.
            from paddle_xpu.layers.nn.lora_layers import (
                XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import (
                XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear
            from paddle_xpu.layers.nn.lora_layers import (
                XPURowParallelLoRALinear as RowParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import (
                XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
            )

            from .lora_layers import (
                ColumnParallelLoRAMergedLinear,
                LoRAConv2D,
                LoRAMergedLinear,
            )

        else:
            raise ImportError  # Force to use the fallback if not XPU
    except ImportError:
        from .lora_layers import (
            ColumnParallelLoRALinear,
            ColumnParallelLoRAMergedLinear,
            ColumnSequenceParallelLoRALinear,
            LoRAConv2D,
            LoRALinear,
            LoRAMergedLinear,
            RowParallelLoRALinear,
            RowSequenceParallelLoRALinear,
        )

    return {
        "ColumnParallelLoRALinear": ColumnParallelLoRALinear,
        "ColumnParallelLoRAMergedLinear": ColumnParallelLoRAMergedLinear,
        "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
        "LoRAConv2D": LoRAConv2D,
        "LoRALinear": LoRALinear,
        "LoRAMergedLinear": LoRAMergedLinear,
        "RowParallelLoRALinear": RowParallelLoRALinear,
        "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
    }


lora_layers = get_lora_layers()
ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"]
ColumnParallelLoRAMergedLinear = lora_layers["ColumnParallelLoRAMergedLinear"]
ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"]
LoRAConv2D = lora_layers["LoRAConv2D"]
LoRALinear = lora_layers["LoRALinear"]
LoRAMergedLinear = lora_layers["LoRAMergedLinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]
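
Because get_lora_layers() rebinds these module-level names, existing call sites keep importing the same symbols and transparently receive the XPU-optimized layers on XPU devices (when paddle_xpu is installed) and the native PaddleNLP layers everywhere else. A minimal usage sketch, not part of this commit; the constructor arguments are assumptions for illustration only:

    # Hypothetical caller: resolves to XPULoRALinear on XPU, native LoRALinear otherwise.
    from paddlenlp.peft.lora.lora_model import LoRALinear

    layer = LoRALinear(in_features=1024, out_features=1024, r=8, lora_alpha=16)  # argument values assumed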

try:
    from ...quantization.quantization_linear import (
        ColumnParallelQuantizationLinear,
