3 changes: 3 additions & 0 deletions library/train_util.py
@@ -2789,6 +2789,9 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
)
parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio")
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")


def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
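The three flags above only register the LoRA+ ratios; the network classes below decide how they are applied. A minimal, self-contained sketch of the intended precedence — a per-component ratio overrides the global one, mirroring the `text_encoder_loraplus_ratio or loraplus_ratio` fallback used further down. The example values are made up:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--loraplus_lr_ratio", default=None, type=float)
parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float)
parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float)

# Hypothetical invocation: a global ratio of 16 with a text-encoder-specific override of 4.
args = parser.parse_args(["--loraplus_lr_ratio", "16", "--loraplus_text_encoder_lr_ratio", "4"])

# Per-component ratios win; the global ratio is only the fallback.
unet_ratio = args.loraplus_unet_lr_ratio or args.loraplus_lr_ratio                  # 16.0
text_encoder_ratio = args.loraplus_text_encoder_lr_ratio or args.loraplus_lr_ratio  # 4.0
print(unet_ratio, text_encoder_ratio)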
60 changes: 48 additions & 12 deletions networks/dylora.py
@@ -406,27 +406,63 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
logger.info(f"weights are merged")
"""

def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_B" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
default_lr if unet_lr is None else unet_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

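For context on how the returned list is consumed, here is a self-contained sketch of the same lora/plus split being fed to torch.optim.AdamW, which is what assemble_params builds toward. The module, lr, and ratio are made-up stand-ins, not the repository's DyLoRAModule:

import torch
import torch.nn as nn

class ToyDyLoRA(nn.Module):
    # Stand-in with the same naming convention: "lora_B" parameters get the boosted lr.
    def __init__(self, in_dim=8, rank=4, out_dim=8):
        super().__init__()
        self.lora_A = nn.Parameter(torch.randn(rank, in_dim))
        self.lora_B = nn.Parameter(torch.zeros(out_dim, rank))

lora = ToyDyLoRA()
lr, ratio = 1e-4, 16.0

param_groups = {"lora": [], "plus": []}
for name, param in lora.named_parameters():
    key = "plus" if "lora_B" in name else "lora"
    param_groups[key].append(param)

all_params = [
    {"params": param_groups["lora"], "lr": lr},          # lora_A keeps the base lr
    {"params": param_groups["plus"], "lr": lr * ratio},  # lora_B is trained at 16x the base lr
]

optimizer = torch.optim.AdamW(all_params)
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])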
75 changes: 54 additions & 21 deletions networks/lora.py
@@ -1035,21 +1035,55 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
return lr_weight

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
params.extend(lora.parameters())
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
print("NO LR skipping!")
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
if self.block_lr:
@@ -1063,21 +1097,20 @@ def enumerate_params(loras):

# set up the parameters for each block
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}

if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

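One subtlety in the block-wise branch above: the per-block lr weight is applied to the base lr before assemble_params multiplies the "plus" group by the LoRA+ ratio. A small arithmetic sketch with made-up values (the 0.5 stands in for whatever self.get_lr_weight returns for that block):

unet_lr = 1e-4
block_lr_weight = 0.5    # assumed value of get_lr_weight(block_loras[0])
loraplus_ratio = 16.0

base_lr = unet_lr * block_lr_weight   # lr of the "lora" group (lora_down weights): 5e-5
plus_lr = base_lr * loraplus_ratio    # lr of the "plus" group (lora_up weights):   8e-4
print(base_lr, plus_lr)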
78 changes: 56 additions & 22 deletions networks/lora_fa.py
@@ -1033,22 +1033,54 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
return lr_weight

# It might be good to allow setting separate learning rates for the two Text Encoders
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
def prepare_optimizer_params(
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True)
all_params = []

def enumerate_params(loras: List[LoRAModule]):
params = []
def assemble_params(loras, lr, ratio):
param_groups = {"lora": {}, "plus": {}}
for lora in loras:
# params.extend(lora.parameters())
params.extend(lora.get_trainable_params())
for name, param in lora.named_parameters():
if ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param

params = []
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}

if len(param_data["params"]) == 0:
continue

if lr is not None:
if key == "plus":
param_data["lr"] = lr * ratio
else:
param_data["lr"] = lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)

return params

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
all_params.append(param_data)
params = assemble_params(
self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

if self.unet_loras:
if self.block_lr:
@@ -1062,21 +1094,20 @@ def enumerate_params(loras: List[LoRAModule]):

# set up the parameters for each block
for idx, block_loras in block_idx_to_lora.items():
param_data = {"params": enumerate_params(block_loras)}

if unet_lr is not None:
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
elif default_lr is not None:
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
if ("lr" in param_data) and (param_data["lr"] == 0):
continue
all_params.append(param_data)
params = assemble_params(
block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

else:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
params = assemble_params(
self.unet_loras,
unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio
)
all_params.extend(params)

return all_params

@@ -1093,6 +1124,9 @@ def on_epoch_start(self, text_encoder, unet):
def get_trainable_params(self):
return self.parameters()

def get_trainable_named_params(self):
return self.named_parameters()

def save_weights(self, file, dtype, metadata):
if metadata is not None and len(metadata) == 0:
metadata = None
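The new get_trainable_named_params accessor added above simply exposes (name, parameter) pairs. A hedged usage sketch of how a caller might consume it, for example to log which tensors will actually receive gradients; the LoRANetwork instance is assumed, not constructed here:

def summarize_trainable(network):
    # `network` is assumed to be a networks/lora_fa.py LoRANetwork instance.
    total = 0
    for name, param in network.get_trainable_named_params():
        if param.requires_grad:
            total += param.numel()
            print(f"trainable: {name} {tuple(param.shape)}")
    print(f"total trainable parameters: {total}")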