
Commit 1933ab4

Fix default_lr being applied

1 parent c769160 commit 1933ab4

File tree

networks/dylora.py
networks/lora.py
networks/lora_fa.py

3 files changed: +64 additions, -17 deletions

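In short: prepare_optimizer_params previously passed text_encoder_lr and unet_lr straight through to assemble_params, so an unset (None) rate reached the optimizer parameter groups instead of falling back to default_lr. After this commit each component's rate falls back to default_lr when its specific rate is None. A minimal sketch of the fallback rule in plain Python (resolve_lr is a hypothetical helper for illustration, not code from the repository):

    # Hypothetical helper mirroring the fallback this commit applies.
    def resolve_lr(specific_lr, default_lr):
        # The component-specific rate wins when given; otherwise use the default.
        return specific_lr if specific_lr is not None else default_lr

    assert resolve_lr(None, 1e-4) == 1e-4   # unset -> default_lr is applied
    assert resolve_lr(5e-5, 1e-4) == 5e-5   # explicit rate is kept as-is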

networks/dylora.py

Lines changed: 18 additions & 3 deletions
@@ -407,7 +407,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         """
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []
 
@@ -442,11 +449,19 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)
 
         return all_params
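After this change a caller can leave one rate unset and still get usable parameter groups. A hedged sketch of such a call (the network variable and the AdamW construction are assumptions for illustration; the keyword names come from the signature above):

    import torch

    # network is assumed to be an already-constructed DyLoRA network.
    params = network.prepare_optimizer_params(
        text_encoder_lr=5e-5,  # explicit: used as-is for the text encoder LoRAs
        unet_lr=None,          # unset: now falls back to default_lr
        default_lr=1e-4,
    )
    optimizer = torch.optim.AdamW(params)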

networks/lora.py

Lines changed: 23 additions & 7 deletions
@@ -1035,7 +1035,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []
 
@@ -1070,7 +1077,11 @@ def assemble_params(loras, lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
@@ -1085,14 +1096,19 @@ def assemble_params(loras, lr, lora_plus_ratio):
 
             # set the parameters per block
             for idx, block_loras in block_idx_to_lora.items():
-                if unet_lr is not None:
-                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                elif default_lr is not None:
-                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)
 
         else:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)
 
         return all_params
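One behavioral detail in the block-wise branch: in the old code, when unet_lr and default_lr were both None neither branch assigned params, so the following all_params.extend(params) would reuse a stale list from an earlier iteration, or raise NameError on the first pass. The unconditional fallback removes that edge case. A worked example of the resulting block rate (values hypothetical; lr_weight stands in for self.get_lr_weight(block_loras[0])):

    # Block-wise rate after the change: the resolved base rate scaled per block.
    unet_lr, default_lr = None, 1e-4
    lr_weight = 0.5  # hypothetical per-block weight
    block_lr = (unet_lr if unet_lr is not None else default_lr) * lr_weight
    assert block_lr == 5e-5  # default_lr scaled by the block's weight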

networks/lora_fa.py

Lines changed: 23 additions & 7 deletions
@@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
+    def prepare_optimizer_params(
+        self,
+        text_encoder_lr,
+        unet_lr,
+        default_lr,
+        unet_lora_plus_ratio=None,
+        text_encoder_lora_plus_ratio=None
+    ):
         self.requires_grad_(True)
         all_params = []
 
@@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            params = assemble_params(
+                self.text_encoder_loras,
+                text_encoder_lr if text_encoder_lr is not None else default_lr,
+                text_encoder_lora_plus_ratio
+            )
             all_params.extend(params)
 
         if self.unet_loras:
@@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio):
 
             # set the parameters per block
            for idx, block_loras in block_idx_to_lora.items():
-                if unet_lr is not None:
-                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
-                elif default_lr is not None:
-                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                params = assemble_params(
+                    block_loras,
+                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
+                    unet_lora_plus_ratio
+                )
                 all_params.extend(params)
 
         else:
-            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            params = assemble_params(
+                self.unet_loras,
+                default_lr if unet_lr is None else unet_lr,
+                unet_lora_plus_ratio
+            )
             all_params.extend(params)
 
         return all_params
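networks/lora_fa.py receives the same changes as networks/lora.py, hunk for hunk: the fallback for the text encoder groups, plus the unconditional fallback in both the block-wise loop and the plain U-Net branch, so the worked example above applies here unchanged.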
