[XPU] add lora optimization (PaddlePaddle#8527)
* [XPU] add lora optimization

* fix

* refine
dynamicheart authored and FeixLiu committed Jul 26, 2024
1 parent 727ea59 commit e916bbd
Showing 1 changed file with 62 additions and 0 deletions.
paddlenlp/peft/lora/lora_model.py (+62, -0)
@@ -78,6 +78,68 @@ class RowSequenceParallelLinear:
    RowParallelLoRALinear,
    RowSequenceParallelLoRALinear,
]


def get_lora_layers():
    try:
        if get_env_device() == "xpu":
            # If paddle_xpu is not installed, fall back to PaddleNLP's native LoRA layers
            from paddle_xpu.layers.nn.lora_layers import (
                XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import (
                XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import XPULoRALinear as LoRALinear
            from paddle_xpu.layers.nn.lora_layers import (
                XPURowParallelLoRALinear as RowParallelLoRALinear,
            )
            from paddle_xpu.layers.nn.lora_layers import (
                XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
            )

            from .lora_layers import (
                ColumnParallelLoRAMergedLinear,
                LoRAConv2D,
                LoRAMergedLinear,
            )

        else:
            raise ImportError  # Force the fallback for non-XPU devices
    except ImportError:
        from .lora_layers import (
            ColumnParallelLoRALinear,
            ColumnParallelLoRAMergedLinear,
            ColumnSequenceParallelLoRALinear,
            LoRAConv2D,
            LoRALinear,
            LoRAMergedLinear,
            RowParallelLoRALinear,
            RowSequenceParallelLoRALinear,
        )

    return {
        "ColumnParallelLoRALinear": ColumnParallelLoRALinear,
        "ColumnParallelLoRAMergedLinear": ColumnParallelLoRAMergedLinear,
        "ColumnSequenceParallelLoRALinear": ColumnSequenceParallelLoRALinear,
        "LoRAConv2D": LoRAConv2D,
        "LoRALinear": LoRALinear,
        "LoRAMergedLinear": LoRAMergedLinear,
        "RowParallelLoRALinear": RowParallelLoRALinear,
        "RowSequenceParallelLoRALinear": RowSequenceParallelLoRALinear,
    }


lora_layers = get_lora_layers()
ColumnParallelLoRALinear = lora_layers["ColumnParallelLoRALinear"]
ColumnParallelLoRAMergedLinear = lora_layers["ColumnParallelLoRAMergedLinear"]
ColumnSequenceParallelLoRALinear = lora_layers["ColumnSequenceParallelLoRALinear"]
LoRAConv2D = lora_layers["LoRAConv2D"]
LoRALinear = lora_layers["LoRALinear"]
LoRAMergedLinear = lora_layers["LoRAMergedLinear"]
RowParallelLoRALinear = lora_layers["RowParallelLoRALinear"]
RowSequenceParallelLoRALinear = lora_layers["RowSequenceParallelLoRALinear"]

try:
    from ...quantization.quantization_linear import (
        ColumnParallelQuantizationLinear,

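Design note on the hunk above: when get_env_device() reports "xpu" (the helper is assumed to be imported elsewhere in lora_model.py, since the hunk starts at line 78 of the file), get_lora_layers() imports the XPU-optimized classes from the optional paddle_xpu package; on any other device it deliberately raises ImportError, so "paddle_xpu not installed" and "not an XPU host" funnel into the same except branch and there is exactly one fallback path to PaddleNLP's native layers. The module-level rebinding keeps the public import surface unchanged, as the minimal usage sketch below illustrates; the constructor arguments in it are illustrative assumptions, not taken from this diff.

    # Minimal usage sketch (not part of the commit): callers keep importing
    # the LoRA layer classes from lora_model exactly as before. The names are
    # rebound once at import time, so on an XPU host with paddle_xpu installed
    # they resolve to the XPU-optimized variants, and to the native layers
    # everywhere else.
    from paddlenlp.peft.lora.lora_model import LoRALinear

    # Illustrative arguments; see paddlenlp/peft/lora/lora_layers.py for the
    # actual signatures.
    layer = LoRALinear(in_features=768, out_features=768, r=8, lora_alpha=16)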