[LoRA] add quick_lora #8106

Merged
merged 12 commits on Mar 25, 2024
6 changes: 6 additions & 0 deletions llm/argument.py
@@ -126,6 +126,12 @@ class ModelArgument:
lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use quick lora, The use of Quick LoRa will only take effect when lora_dropout is set to 0."
},
)
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})

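A note on the constraint in the help text above: with lora_dropout set to 0, the base projection and the low-rank update see the same input, so the update can be folded into a single matmul, which is the kind of fusion a quick path can exploit (whether quick_lora uses exactly this fusion is not shown in this diff). A minimal Paddle sketch of the identity, with illustrative shapes and values:

import paddle
import paddle.nn.functional as F

# Illustrative shapes: batch 4, in_features 16, out_features 32, LoRA rank 8.
x = paddle.randn([4, 16])
weight = paddle.randn([16, 32])  # Paddle stores Linear weights as [in_features, out_features]
bias = paddle.randn([32])
lora_A = paddle.randn([16, 8])
lora_B = paddle.randn([8, 32])
scaling = 2.0

# Standard LoRA forward: base matmul plus a separate low-rank update.
reference = F.linear(x, weight, bias) + (x @ lora_A @ lora_B) * scaling
# With lora_dropout == 0 the update can be folded into the base weight.
fused = F.linear(x, weight + scaling * (lora_A @ lora_B), bias)
print(paddle.allclose(reference, fused, rtol=1e-4, atol=1e-4))  # True up to float rounding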
3 changes: 3 additions & 0 deletions llm/finetune_generation.py
@@ -112,6 +112,7 @@ def main():
weight_double_quant=model_args.weight_double_quant,
weight_double_quant_block_size=model_args.weight_double_quant_block_size,
)

if training_args.pipeline_parallel_degree > 1:
if data_args.eval_with_do_generation and training_args.do_eval:
raise ValueError("Please set eval_with_do_generation to false in pipeline parallel mode.")
@@ -426,10 +427,12 @@ def neft_post_hook(module, input, output):
dtype=dtype,
do_qat=quant_args.do_qat,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
)
model = LoRAModel(model, lora_config)
else:
model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)

model.print_trainable_parameters()

def compute_metrics_do_generation(eval_preds):
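The wiring above threads model_args.use_quick_lora into LoRAConfig before the model is wrapped in LoRAModel. A rough end-to-end sketch of that flow follows; the checkpoint name, target_modules patterns, and rank are illustrative assumptions, and only the use_quick_lora flag and the dropout constraint come from this PR.

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")  # illustrative checkpoint
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # assumed regex patterns, adjust per model
    r=8,
    lora_alpha=16,
    lora_dropout=0.0,     # Quick LoRA only takes effect when dropout is 0
    use_quick_lora=True,  # flag added by this PR
)
model = LoRAModel(model, lora_config)
model.print_trainable_parameters()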
15 changes: 15 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -18,6 +18,7 @@
from typing import List, Optional, Union

from ...utils.env import LORA_CONFIG_NAME
from ...utils.log import logger


@dataclass
@@ -77,6 +78,20 @@
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
)
use_quick_lora: bool = field(
default=False,
metadata={
"help": "Whether to use quick lora, The use of Quick LoRa will only take effect when lora_dropout is set to 0."
},
)

def __post_init__(self):
if self.use_quick_lora and self.lora_dropout > 0:
logger.warning(
"Quick LoRA is enabled, but lora_dropout is set to a non-zero value. "
"We will automatically set `use_quick_lora` to `False` to avoid potential inconsistencies."
)
self.use_quick_lora = False

@property
def __dict__(self):
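The __post_init__ guard above downgrades the flag instead of raising, so a misconfigured run continues with plain LoRA. A short sketch of the expected behavior, assuming defaults for the remaining LoRAConfig fields:

from paddlenlp.peft import LoRAConfig

# lora_dropout > 0: the guard logs a warning and silently disables Quick LoRA.
config = LoRAConfig(use_quick_lora=True, lora_dropout=0.1)
assert config.use_quick_lora is False

# lora_dropout == 0: the flag is kept as requested.
config = LoRAConfig(use_quick_lora=True, lora_dropout=0.0)
assert config.use_quick_lora is True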
130 changes: 94 additions & 36 deletions paddlenlp/peft/lora/lora_layers.py
@@ -25,6 +25,8 @@
RowParallelLinear,
)

from .lora_quick_layers import quick_lora

if "npu" in paddle.device.get_all_custom_device_type():
from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
else:
@@ -42,6 +44,7 @@
lora_alpha: int = 1,
lora_dropout: float = 0.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
rslora: bool = False,
lora_plus_scale: float = 1.0,
**kwargs
@@ -84,6 +87,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -102,9 +110,13 @@
self.merged = True

def forward(self, input: paddle.Tensor, *args, **kwargs):
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
if self.use_quick_lora:
# Use the quick lora implementation
result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)

else:
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if not self.merged:
result += (self.lora_dropout(input) @ self.lora_A @ self.lora_B) * self.scaling
return result

def extra_repr(self):
@@ -123,6 +135,7 @@
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -171,6 +184,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -194,33 +212,52 @@
else:
input_mp = x

# x @ W : [bz, in_f / ws] ===> [bz, out_f]
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
else:
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)

if self.use_quick_lora:

# Use the quick lora implementation
result_mp = quick_lora(

input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_row=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
# x @ W : [bz, in_f / ws] ===> [bz, out_f]
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)

else:
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
output = mp_ops._mp_allreduce(

result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A
# all-reduce to keep LoRA B's gradients consistent across GPUs
input_dup = mp_ops._mp_allreduce(
input_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output
if not self.merged:

# x @ A: [bz, in_f/ ws] ===> [bz, r]
input_mp = self.lora_dropout(input_mp) @ self.lora_A

# all-reduce to keep LoRA B's gradients consistent across GPUs
input_dup = mp_ops._mp_allreduce(

input_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
# @ B: [bz, r] ===> [bz, out_f]
delta_mp = (input_dup @ self.lora_B) * self.scaling
output += delta_mp
output = output + self.bias if self.bias is not None else output

return output

def extra_repr(self):
Expand All @@ -240,6 +277,7 @@
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
@@ -286,6 +324,11 @@

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
@@ -304,22 +347,37 @@
self.merged = True

def forward(self, input: paddle.Tensor):
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
result_mp = res_mp + self.bias
if self.use_quick_lora:

# Use the quick lora implementation
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) if self.is_mp else input
result_mp = quick_lora(

input_mp,
self.lora_A,
self.lora_B,
self.weight,
self.bias,
self.scaling,
is_column=True,
group=self.model_parallel_group,
world_size=self.world_size,
)
else:
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = tmp * self.scaling
res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
result_mp = res_mp + self.bias

else:
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)

if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = tmp * self.scaling

else:
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
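Across the three layer variants, the quick path is gated by the use_quick_lora property, so it is only taken during training and while weights are unmerged; otherwise the original forward is used. For the plain (non-parallel) branch, the call above implies quick_lora(x, lora_A, lora_B, weight, bias, scaling) should match the unfused forward when dropout is 0. A rough check under that assumption, with illustrative shapes:

import paddle
import paddle.nn.functional as F
from paddlenlp.peft.lora.lora_quick_layers import quick_lora

x = paddle.randn([4, 16])
weight = paddle.randn([16, 32])
bias = paddle.randn([32])
lora_A = paddle.randn([16, 8])
lora_B = paddle.randn([8, 32])
scaling = 0.5

expected = F.linear(x, weight, bias) + (x @ lora_A @ lora_B) * scaling
got = quick_lora(x, lora_A, lora_B, weight, bias, scaling)
print(paddle.allclose(expected, got, rtol=1e-4, atol=1e-4))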
3 changes: 3 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -385,6 +385,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
bias_attr=False if module.bias is None else None,
use_quick_lora=lora_config.use_quick_lora,
)
if isinstance(module, nn.Conv2D):
lora_module = LoRAConv2D(
@@ -422,6 +423,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
)
),
use_quick_lora=lora_config.use_quick_lora,
)
# LoRA column parallel will split the LoRA B matrix
self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)
@@ -444,6 +446,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
merge_weights=lora_config.merge_weights,
use_quick_lora=lora_config.use_quick_lora,
)
# LoRA column parallel will split the LoRA A matrix
self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)
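Since _find_and_replace_module now forwards lora_config.use_quick_lora to the LoRA linear layers and their row/column-parallel variants, one way to confirm which sublayers picked it up is to read the property added in lora_layers.py. A small inspection sketch, assuming model is the LoRAModel built earlier; the property reports False in eval mode or once weights are merged:

model.train()  # quick LoRA is only reported while training with unmerged weights
for name, layer in model.named_sublayers():
    if hasattr(layer, "use_quick_lora"):
        print(name, type(layer).__name__, layer.use_quick_lora)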