Pissa #8250

Merged (6 commits) on Apr 11, 2024

1 change: 1 addition & 0 deletions llm/argument.py
@@ -196,6 +196,7 @@ class ModelArgument:
)
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})

# prefix tuning related parameters
prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
1 change: 1 addition & 0 deletions llm/finetune_generation.py
@@ -446,6 +446,7 @@ def neft_post_hook(module, input, output):
lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
rslora=model_args.rslora,
lora_plus_scale=model_args.lora_plus_scale,
pissa=model_args.pissa,
merge_weights=False,
tensor_parallel_degree=training_args.tensor_parallel_degree,
dtype=dtype,
2 changes: 1 addition & 1 deletion llm/llama/lora_argument.json
@@ -29,4 +29,4 @@
"lora": true,
"zero_padding": false,
"use_flash_attention": false
}
}
33 changes: 33 additions & 0 deletions llm/llama/lora_argument_pissa.json
@@ -0,0 +1,33 @@
{
"model_name_or_path": "facebook/llama-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/llama_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 32,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 2e-05,
"warmup_steps": 10,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"fp16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"pissa": false,
"zero_padding": false,
"use_flash_attention": false
}
33 changes: 33 additions & 0 deletions llm/qwen/lora_argument_pissa.json
@@ -0,0 +1,33 @@
{
"model_name_or_path": "qwen/qwen-7b",
"dataset_name_or_path": "./data",
"output_dir": "./checkpoints/qwen_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 32,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
"num_train_epochs": 3,
"learning_rate": 2e-05,
"warmup_steps": 10,
"logging_steps": 1,
"evaluation_strategy": "epoch",
"save_strategy": "epoch",
"src_length": 1024,
"max_length": 2048,
"bf16": true,
"fp16_opt_level": "O2",
"do_train": true,
"do_eval": true,
"disable_tqdm": true,
"load_best_model_at_end": true,
"eval_with_do_generation": false,
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
"tensor_parallel_degree": 1,
"pipeline_parallel_degree": 1,
"lora": true,
"pissa": true,
"zero_padding": false,
"use_flash_attention": false
}
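
The two new sample configs differ from the stock lora_argument.json recipes only in the added "pissa" key (false in the llama example, true in the qwen one). As a hedged sketch of how that key reaches the PEFT config when one of these JSONs drives finetune_generation.py, mirroring the change above (the target module patterns and rank below are illustrative placeholders, not values from this PR):

import json

from paddlenlp.peft import LoRAConfig

with open("llm/qwen/lora_argument_pissa.json") as f:
    cfg = json.load(f)

# model_args.pissa -> LoRAConfig.pissa, exactly as in the finetune_generation.py diff.
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # placeholder patterns
    r=8,                                          # placeholder rank
    lora_alpha=16,                                # no effect when pissa is true: scaling is forced to 1.0
    pissa=cfg["pissa"],
    tensor_parallel_degree=cfg["tensor_parallel_degree"],
)
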
1 change: 1 addition & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -74,6 +74,7 @@ class LoRAConfig:
)
do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
40 changes: 39 additions & 1 deletion paddlenlp/peft/lora/lora_layers.py
@@ -47,6 +47,7 @@ def __init__(
use_quick_lora: bool = False,
rslora: bool = False,
lora_plus_scale: float = 1.0,
pissa: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,6 +63,7 @@ def __init__(
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights
self.pissa = pissa

# Actual trainable parameters
self.lora_A = self.create_parameter(
@@ -79,9 +81,12 @@ def __init__(
learning_rate=lora_plus_scale,
),
)
self.apply_pissa = False

if not rslora:
if not rslora and not pissa:
self.scaling = self.lora_alpha / self.r
elif pissa:
self.scaling = 1.0
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

@@ -93,6 +98,25 @@ def __init__(
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def pissa_init(self, rank):
weight = self.weight
dtype = weight.dtype
if dtype != paddle.float32:
weight = weight.astype(paddle.float32)

U, S, Vh = paddle.linalg.svd(weight.data, full_matrices=False)
Ur = U[:, :rank]
Sr = S[:rank]
Vhr = Vh[:rank]

lora_A = Ur @ paddle.diag(paddle.sqrt(Sr))
lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr
self.lora_A.set_value(lora_A.astype(dtype))
self.lora_B.set_value(lora_B.astype(dtype))
res = weight.data - lora_A @ lora_B
weight = res.astype(dtype)
self.weight.set_value(weight)

def train(self):
super().train()
if self.merge_weights and self.merged:
@@ -110,6 +134,10 @@ def eval(self):
self.merged = True

def forward(self, input: paddle.Tensor, *args, **kwargs):
if not self.apply_pissa and self.pissa:
self.pissa_init(self.r)
self.apply_pissa = True

if self.use_quick_lora:
# Use the quick lora implementation
result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
@@ -136,11 +164,16 @@ def __init__(
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
pissa: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa:
raise ValueError("Pissa is not supported in model parallel by now")

self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
@@ -278,11 +311,16 @@ def __init__(
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
pissa: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa:
raise ValueError("Pissa is not supported in model parallel by now")

self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
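
The core of the change is pissa_init above: the dense weight is decomposed by a truncated SVD, lora_A = Ur @ diag(sqrt(Sr)) and lora_B = diag(sqrt(Sr)) @ Vhr become the trainable factors, the frozen weight is replaced by the residual weight - lora_A @ lora_B, and scaling is fixed to 1.0. Since the split is exact, a PiSSA-initialized layer should reproduce the original layer's output on its first forward pass. A minimal sanity-check sketch, assuming the non-parallel layer class in this file is named LoRALinear (the parallel variants reject pissa):

import paddle

from paddlenlp.peft.lora.lora_layers import LoRALinear  # class name assumed

paddle.seed(0)
layer = LoRALinear(in_features=16, out_features=16, r=4, lora_alpha=8, pissa=True)
w0 = layer.weight.clone()  # dense weight before the residual split
x = paddle.randn([2, 16])

y = layer(x)  # the first forward call triggers pissa_init()

# After init the stored weight is w0 - lora_A @ lora_B and scaling == 1.0,
# so the output still equals x @ w0 + bias up to float32 round-off.
y_ref = paddle.matmul(x, w0) + layer.bias
assert layer.scaling == 1.0
assert paddle.allclose(y, y_ref, atol=1e-5).item()
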
3 changes: 3 additions & 0 deletions paddlenlp/peft/lora/lora_model.py
@@ -384,6 +384,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
merge_weights=lora_config.merge_weights,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
pissa=lora_config.pissa,
bias_attr=False if module.bias is None else None,
use_quick_lora=lora_config.use_quick_lora,
)
@@ -417,6 +418,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
pissa=lora_config.pissa,
merge_weights=lora_config.merge_weights,
lora_A_weight_attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(
@@ -445,6 +447,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
lora_dropout=lora_config.lora_dropout,
rslora=lora_config.rslora,
lora_plus_scale=lora_config.lora_plus_scale,
pissa=lora_config.pissa,
merge_weights=lora_config.merge_weights,
use_quick_lora=lora_config.use_quick_lora,
)
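
End to end, enabling PiSSA from user code is a one-flag change on top of the usual LoRA flow. A minimal sketch, assuming the standard paddlenlp.peft entry points; the model name comes from the sample JSON, while the target module patterns and rank are placeholders:

from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # placeholder patterns
    r=8,
    lora_alpha=16,  # no effect under PiSSA: the layer forces scaling = 1.0
    pissa=True,     # lora_A / lora_B are SVD-initialized on the first forward pass
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()
model.print_trainable_parameters()

Note that the row- and column-parallel layers raise a ValueError when pissa is set, so keep tensor_parallel_degree at 1, as the sample configs do.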