
Commit 3e1189f

Merge pull request kohya-ss#1233 from rockerBOO/lora-plus
Add LoRA+ support
2 parents 92252a1 + be34692

File tree

5 files changed (+220, -74 lines)


library/train_util.py

Lines changed: 3 additions & 0 deletions
@@ -2920,6 +2920,9 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
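
To make the fallback behavior of these flags concrete, here is a small self-contained sketch (not part of the commit): a per-model ratio, when given, takes precedence over the generic --loraplus_lr_ratio, mirroring the "per-model ratio or loraplus_ratio" fallback used inside prepare_optimizer_params in the network diffs below. The example values 16 and 8 are arbitrary.

import argparse

# Minimal reproduction of the arguments added above, used only to show the
# precedence rule; values are illustrative.
parser = argparse.ArgumentParser()
parser.add_argument("--loraplus_lr_ratio", default=None, type=float)
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float)
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float)

args = parser.parse_args(["--loraplus_lr_ratio", "16", "--loraplus_unet_lr_ratio", "8"])
unet_ratio = args.loraplus_unet_lr_ratio or args.loraplus_lr_ratio                  # 8.0
text_encoder_ratio = args.loraplus_text_encoder_lr_ratio or args.loraplus_lr_ratio  # 16.0
print(unet_ratio, text_encoder_ratio)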

networks/dylora.py

Lines changed: 48 additions & 12 deletions
@@ -406,27 +406,63 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         logger.info(f"weights are merged")
     """

-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_B" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         return all_params
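
To make the grouping rule above concrete, the following is a small, self-contained sketch (not from the commit) of the assemble_params idea: parameters whose name marks the "up"/"B" matrix are collected into a "plus" group trained at the base learning rate times the LoRA+ ratio, everything else stays at the base rate. ToyLoRA, the "lora_up" naming, and the ratio value 16 are illustrative assumptions; dylora.py matches on "lora_B" while lora.py matches on "lora_up".

import torch

class ToyLoRA(torch.nn.Module):
    # Illustrative stand-in for a LoRA module with a down (A) and an up (B) matrix.
    def __init__(self, in_dim=8, rank=4, out_dim=8):
        super().__init__()
        self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)

def assemble_params(loras, lr, ratio):
    # Same grouping rule as the diff above: "up" weights go into a "plus" group
    # trained at lr * ratio, everything else stays at lr.
    groups = {"lora": [], "plus": []}
    for lora in loras:
        for name, param in lora.named_parameters():
            key = "plus" if ratio is not None and "lora_up" in name else "lora"
            groups[key].append(param)

    param_groups = []
    for key, params in groups.items():
        if not params:
            continue
        group_lr = lr * ratio if key == "plus" else lr
        if not group_lr:
            continue  # skip zero/None learning rates, as the diff does
        param_groups.append({"params": params, "lr": group_lr})
    return param_groups

param_groups = assemble_params([ToyLoRA(), ToyLoRA()], lr=1e-4, ratio=16)
optimizer = torch.optim.AdamW(param_groups)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])  # 2 params at 1e-04, then 2 params at 1.6e-03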

networks/lora.py

Lines changed: 54 additions & 21 deletions
@@ -1034,21 +1034,55 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    print("NO LR skipping!")
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1096,20 @@ def enumerate_params(loras):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
-
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(
+                    self.unet_loras,
+                    unet_lr if unet_lr is not None else default_lr,
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

         return all_params
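
A quick arithmetic check, with hypothetical numbers, of how the block_lr branch above composes with LoRA+: the learning rate handed to assemble_params already includes the per-block weight from get_lr_weight, and the "plus" group then multiplies it by the ratio.

# Hypothetical values only; block_weight stands in for self.get_lr_weight(block_loras[0]).
unet_lr = 1e-4
block_weight = 0.5
loraplus_ratio = 16

base_lr = unet_lr * block_weight      # lr of the lora_down group: 5e-05
plus_lr = base_lr * loraplus_ratio    # lr of the lora_up group:   8e-04
print(base_lr, plus_lr)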

networks/lora_fa.py

Lines changed: 56 additions & 22 deletions
@@ -1033,22 +1033,54 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight

     # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        text_encoder_loraplus_ratio=None,
+        unet_loraplus_ratio=None,
+        loraplus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []

-        def enumerate_params(loras: List[LoRAModule]):
-            params = []
+        def assemble_params(loras, lr, ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                # params.extend(lora.parameters())
-                params.extend(lora.get_trainable_params())
+                for name, param in lora.named_parameters():
+                    if ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+
+                if len(param_data["params"]) == 0:
+                    continue
+
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
+                    continue
+
+                params.append(param_data)
+
             return params

         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_loraplus_ratio or loraplus_ratio
+            )
+            all_params.extend(params)

         if self.unet_loras:
             if self.block_lr:
@@ -1062,21 +1094,20 @@ def enumerate_params(loras: List[LoRAModule]):

             # blockごとにパラメータを設定する
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
-
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
-                elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

             else:
-                param_data = {"params": enumerate_params(self.unet_loras)}
-                if unet_lr is not None:
-                    param_data["lr"] = unet_lr
-                all_params.append(param_data)
+                params = assemble_params(
+                    self.unet_loras,
+                    unet_lr if unet_lr is not None else default_lr,
+                    unet_loraplus_ratio or loraplus_ratio
+                )
+                all_params.extend(params)

         return all_params

@@ -1093,6 +1124,9 @@ def on_epoch_start(self, text_encoder, unet):
     def get_trainable_params(self):
         return self.parameters()

+    def get_trainable_named_params(self):
+        return self.named_parameters()
+
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
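
The new get_trainable_named_params mirrors get_trainable_params but yields (name, parameter) pairs, which is what name-based grouping like assemble_params needs. A hedged toy illustration follows; ToyModule is a stand-in and not the repository's LoRAModule.

import torch

class ToyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_down = torch.nn.Linear(8, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 8, bias=False)

    def get_trainable_params(self):
        return self.parameters()          # bare parameters: nothing to filter on

    def get_trainable_named_params(self):
        return self.named_parameters()    # (name, parameter) pairs enable "lora_up" filtering

module = ToyModule()
plus_names = [name for name, _ in module.get_trainable_named_params() if "lora_up" in name]
print(plus_names)  # ['lora_up.weight']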
