[LoRA] add quick_lora #8106

Merged · 12 commits · Mar 25, 2024
7 changes: 6 additions & 1 deletion llm/argument.py
@@ -126,7 +126,12 @@ class ModelArgument:
lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})

use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use Quick LoRA. Quick LoRA only takes effect when lora_dropout is set to 0."
},
)
# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})
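Reviewer note: the new ModelArgument field surfaces as an ordinary fine-tuning flag. A minimal parsing sketch, assuming the PdArgumentParser wiring that llm/finetune_generation.py already uses; the import path and every argument value below are illustrative, not taken from this PR.

```python
# Hypothetical parse of the new flag; run from inside the llm/ directory.
from paddlenlp.trainer import PdArgumentParser

from argument import ModelArgument  # llm/argument.py (assumed import path)

parser = PdArgumentParser((ModelArgument,))
(model_args,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "path/to/base_model", "--lora", "true", "--use_quick_lora", "true"]
)
print(model_args.use_quick_lora)  # True; only effective when lora_dropout == 0
```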
1 change: 1 addition & 0 deletions llm/finetune_generation.py
@@ -424,6 +424,7 @@ def neft_post_hook(module, input, output):
dtype=dtype,
do_qat=quant_args.do_qat,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
)
model = LoRAModel(model, lora_config)
else:
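For context, the flag simply rides along into LoRAConfig before the model is wrapped. A self-contained sketch of the equivalent wiring (the model path and target_modules patterns are placeholders; only the use_quick_lora line is what this PR adds):

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base_model")  # placeholder path
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative patterns
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,      # must stay 0.0 for quick LoRA to remain enabled
    use_quick_lora=True,   # the value threaded through from model_args.use_quick_lora
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()
model.print_trainable_parameters()
```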
15 changes: 15 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -18,6 +18,7 @@
from typing import List, Optional, Union

from ...utils.env import LORA_CONFIG_NAME
from ...utils.log import logger


@dataclass
@@ -75,6 +76,20 @@
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use Quick LoRA. Quick LoRA only takes effect when lora_dropout is set to 0."
},
)

def __post_init__(self):
if self.use_quick_lora and self.lora_dropout > 0:
logger.warning(
"Quick LoRA is enabled, but lora_dropout is set to a non-zero value. "
"We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
)
self.use_quick_lora = False

@property
def __dict__(self):
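The fallback above is easy to exercise directly; a minimal sketch, assuming LoRAConfig keeps its current field names (target_modules, r, lora_alpha, lora_dropout):

```python
from paddlenlp.peft import LoRAConfig

# dropout > 0: __post_init__ logs the warning and turns quick LoRA back off
cfg = LoRAConfig(target_modules=[".*q_proj.*"], r=8, lora_alpha=16,
                 lora_dropout=0.1, use_quick_lora=True)
print(cfg.use_quick_lora)  # False

# dropout == 0: the flag is left as requested
cfg = LoRAConfig(target_modules=[".*q_proj.*"], r=8, lora_alpha=16,
                 lora_dropout=0.0, use_quick_lora=True)
print(cfg.use_quick_lora)  # True
```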
119 changes: 90 additions & 29 deletions paddlenlp/peft/lora/lora_layers.py
@@ -24,6 +24,8 @@
RowParallelLinear,
)

from .lora_quick_layers import quick_lora


class LoRALinear(nn.Linear):
# LoRA implemented in a dense layer
@@ -35,6 +37,7 @@
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -68,6 +71,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -86,9 +94,12 @@
self.merged = True

def forward(self, input: paddle.Tensor, *args, **kwargs):
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
if self.use_quick_lora:
result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)

else:
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
return result

def extra_repr(self):
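Reviewer note: quick_lora itself lives in lora_quick_layers.py, which is outside the hunks shown here, so the sketch below only pins down the semantics that the else branch in the forward above already defines; presumably the speedup comes from a fused custom forward/backward rather than from different math (an assumption, not something this diff states).

```python
import paddle.nn.functional as F

def lora_forward_reference(x, lora_A, lora_B, weight, bias, scaling):
    """Reference math the quick path must reproduce (un-merged weights, zero dropout)."""
    # frozen base projection: [bz, in_f] x [in_f, out_f] -> [bz, out_f]
    base = F.linear(x=x, weight=weight, bias=bias)
    # low-rank update: ([bz, in_f] x [in_f, r]) x [r, out_f], scaled by lora_alpha / r
    delta = (x @ lora_A @ lora_B) * scaling
    return base + delta
```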
@@ -105,6 +116,7 @@
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
Expand Down Expand Up @@ -146,6 +158,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -169,30 +186,48 @@
else:
input_mp = x

# x @ W : [bz, in_f / ws] ===> [bz, out_f]
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)

output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A
# all-reduce to keep LoRA B's gradient consistent across GPUs
input_dup = mp_ops._mp_allreduce(
if self.use_quick_lora:
result_mp = quick_lora(

input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_row=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
output = mp_ops._mp_allreduce(

result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
# x @ W : [bz, in_f / ws] ===> [bz, out_f]
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
output = mp_ops._mp_allreduce(

result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output

if not self.merged:

# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A

# all-reduce to keep LoRA B's gradient consistent across GPUs
input_dup = mp_ops._mp_allreduce(

input_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output

return output

def extra_repr(self):
@@ -210,6 +245,7 @@
lora_dropout: float = 0.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -249,6 +285,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -267,14 +308,34 @@
self.merged = True

def forward(self, input: paddle.Tensor):
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
if self.is_mp:
input_mp = mp_ops._c_identity(

input,
group=self.model_parallel_group,
)
else:
input_mp = input
if self.use_quick_lora:

# Use the quick lora implementation
result_mp = quick_lora(

input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_column=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
else:
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp
if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
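A quick numerical sanity check for the single-GPU layer is sketched below. It assumes LoRALinear keeps its current constructor arguments (r, lora_alpha, lora_dropout) and pokes the private _use_quick_lora switch only to force the reference path, which is test-only behaviour rather than public API; the tensor-parallel variants (is_row/is_column plus group/world_size) need a distributed launch and are not covered here.

```python
import paddle
from paddlenlp.peft.lora.lora_layers import LoRALinear

paddle.seed(0)
layer = LoRALinear(64, 64, r=8, lora_alpha=16, lora_dropout=0.0, use_quick_lora=True)
layer.train()                      # quick LoRA only applies while training and un-merged
x = paddle.randn([4, 64], dtype=layer.weight.dtype)

y_quick = layer(x)                 # dispatches to quick_lora
layer._use_quick_lora = False      # force the F.linear + low-rank reference path
y_ref = layer(x)

print(paddle.allclose(y_quick, y_ref, atol=1e-5))  # expected: True
```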
3 changes: 3 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -383,6 +383,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_dropout=lora_config.lora_dropout,
merge_weights=lora_config.merge_weights,
bias_attr=False if module.bias is None else None,
use_quick_lora=lora_config.use_quick_lora,
)
if isinstance(module, nn.Conv2D):
lora_module = LoRAConv2D(
@@ -418,6 +419,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
)
),
use_quick_lora=lora_config.use_quick_lora,
)
# Lora column parallel will split lora B matrix
self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
@@ -438,6 +440,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_alpha=lora_config.lora_alpha,
lora_dropout=lora_config.lora_dropout,
merge_weights=lora_config.merge_weights,
use_quick_lora=lora_config.use_quick_lora,
)
# Lora row parallel will split lora A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
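Once _find_and_replace_module has swapped the target modules, a quick way to confirm the flag actually reached the injected layers; named_sublayers is the standard paddle.nn.Layer API, and model is assumed to be the wrapped LoRAModel from the earlier sketch.

```python
from paddlenlp.peft.lora.lora_layers import LoRALinear

for name, sublayer in model.named_sublayers():
    if isinstance(sublayer, LoRALinear):
        # _use_quick_lora is the internal switch set from lora_config.use_quick_lora
        print(name, sublayer._use_quick_lora)
```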