Dropout and Max Norm Regularization for LoRA training #545

Merged
14 commits, merged on Jun 1, 2023
33 changes: 33 additions & 0 deletions library/custom_train_functions.py
@@ -434,3 +434,36 @@ def perlin_noise(noise, device, octaves):
noise += noise_perlin # broadcast for each batch
return noise / noise.std() # Scaled back to roughly unit variance
"""

def max_norm(state_dict, max_norm_value):
    """Scale each LoRA up/down weight pair so the norm of its combined update stays at or below max_norm_value."""
    downkeys = []
    upkeys = []
    norms = []
    keys_scaled = 0

    # Collect the matching lora_down / lora_up weight keys.
    for key in state_dict.keys():
        if "lora_down" in key and "weight" in key:
            downkeys.append(key)
            upkeys.append(key.replace("lora_down", "lora_up"))

    for i in range(len(downkeys)):
        down = state_dict[downkeys[i]].cuda()
        up = state_dict[upkeys[i]].cuda()

        # Reconstruct the effective update (up @ down); 1x1 and 3x3 conv LoRAs need their own handling.
        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        # Clamp the norm from below so the division stays stable, then cap it at the threshold.
        norm = updown.norm().clamp(min=max_norm_value / 2)
        desired = torch.clamp(norm, max=max_norm_value)
        ratio = desired.cpu() / norm.cpu()
        sqrt_ratio = ratio ** 0.5
        if ratio != 1:
            keys_scaled += 1
            # Scale both halves by sqrt(ratio) so their product shrinks by the full ratio.
            state_dict[upkeys[i]] *= sqrt_ratio
            state_dict[downkeys[i]] *= sqrt_ratio
        scalednorm = updown.norm() * ratio
        norms.append(scalednorm.item())

    return keys_scaled, sum(norms) / len(norms), max(norms)
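
The scaling step works because the effective LoRA update is the product of the two matrices: multiplying both lora_up and lora_down by the square root of the ratio scales up @ down, and therefore its norm, by the full ratio. Below is a minimal sketch of that idea on a single linear pair; the shapes and the threshold value are illustrative assumptions, and the lower clamp on the norm used in the PR is omitted here.

import torch

# Hypothetical rank-4 LoRA pair for a 320x320 linear layer (shapes chosen for illustration).
down = torch.randn(4, 320)   # lora_down.weight
up = torch.randn(320, 4)     # lora_up.weight
threshold = 1.0              # plays the role of --scale_weight_norms

updown = up @ down                                 # effective update added to the frozen weight
norm = updown.norm()                               # Frobenius norm of that update
ratio = (norm.clamp(max=threshold) / norm).item()  # 1.0 when already under the threshold

# Scaling each half by sqrt(ratio) scales their product (and its norm) by ratio.
up_scaled = up * ratio ** 0.5
down_scaled = down * ratio ** 0.5
assert torch.allclose((up_scaled @ down_scaled).norm(), norm * ratio, rtol=1e-4)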

2 changes: 1 addition & 1 deletion library/train_util.py
@@ -3638,4 +3638,4 @@ def __call__(self, examples):
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
return examples[0]
16 changes: 12 additions & 4 deletions networks/lora.py
@@ -19,7 +19,7 @@ class LoRAModule(torch.nn.Module):
replaces forward method of the original Linear, instead of replacing the original Linear module.
"""

def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, dropout=None):
"""if alpha == 0 or None, alpha is rank (no scaling)."""
super().__init__()
self.lora_name = lora_name
@@ -60,14 +60,18 @@ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_

self.multiplier = multiplier
self.org_module = org_module # remove in applying
self.dropout = dropout

def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module

def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
if self.dropout:
return self.org_forward(x) + self.lora_up(torch.nn.functional.dropout(self.lora_down(x),p=self.dropout)) * self.multiplier * self.scale
else:
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale


class LoRAInfModule(LoRAModule):
@@ -348,7 +352,7 @@ def parse_block_lr_kwargs(nw_kwargs):
return down_lr_weight, mid_lr_weight, up_lr_weight


def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, dropout=None, **kwargs):
if network_dim is None:
network_dim = 4 # default
if network_alpha is None:
@@ -403,6 +407,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_block_dims=conv_block_dims,
conv_block_alphas=conv_block_alphas,
varbose=True,
dropout=dropout,
)

if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
@@ -681,6 +686,7 @@ def __init__(
modules_alpha=None,
module_class=LoRAModule,
varbose=False,
dropout=None
) -> None:
"""
        LoRA network: there are a lot of arguments, but the patterns are as follows
@@ -697,6 +703,8 @@
self.alpha = alpha
self.conv_lora_dim = conv_lora_dim
self.conv_alpha = conv_alpha
self.dropout = dropout
print(f"Neuron dropout: p={self.dropout}")

if modules_dim is not None:
print(f"create LoRA network from weights")
@@ -755,7 +763,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules
skipped.append(lora_name)
continue

lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, dropout)
loras.append(lora)
return loras, skipped

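For context, the modified forward leaves the frozen module's output untouched and applies torch.nn.functional.dropout only to the low-rank activation; note that the functional form applies dropout whenever it is called, since its training flag defaults to True. The following is a stripped-down sketch of the same pattern on a plain nn.Linear; the class name, rank, and sample shapes are illustrative and not part of the PR.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLoRALinear(nn.Module):
    # Minimal stand-in for LoRAModule: a frozen base layer plus a low-rank branch with dropout.
    def __init__(self, base: nn.Linear, rank=4, alpha=1.0, multiplier=1.0, dropout=None):
        super().__init__()
        self.base = base
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # start the LoRA contribution at zero, as LoRA implementations typically do
        self.scale = alpha / rank
        self.multiplier = multiplier
        self.dropout = dropout

    def forward(self, x):
        h = self.lora_down(x)
        if self.dropout:
            # F.dropout applies here regardless of train/eval mode (its training flag defaults to True),
            # matching the PR's forward; only the low-rank branch is affected.
            h = F.dropout(h, p=self.dropout)
        return self.base(x) + self.lora_up(h) * self.multiplier * self.scale

lora = TinyLoRALinear(nn.Linear(320, 320), rank=4, dropout=0.1)
y = lora(torch.randn(2, 320))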
35 changes: 31 additions & 4 deletions train_network.py
@@ -25,12 +25,16 @@
)
import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like, apply_noise_offset, max_norm


# TODO: share this logic with the other scripts
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.scale_weight_norms:
logs["keys_scaled"] = keys_scaled
logs["average_key_norm"] = mean_norm
logs["max_key_norm"] = maximum_norm

lrs = lr_scheduler.get_last_lr()

@@ -196,13 +200,14 @@ def train(args):
if args.dim_from_weights:
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
else:
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, args.dropout, **net_kwargs)
if network is None:
return

if hasattr(network, "prepare_network"):
network.prepare_network(args)


train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
@@ -375,6 +380,8 @@ def train(args):
"ss_face_crop_aug_range": args.face_crop_aug_range,
"ss_prior_loss_weight": args.prior_loss_weight,
"ss_min_snr_gamma": args.min_snr_gamma,
"ss_scale_weight_norms": args.scale_weight_norms,
"ss_dropout": args.dropout,
}

if use_user_config:
@@ -580,6 +587,7 @@ def remove_model(old_ckpt_name):
network.on_epoch_start(text_encoder, unet)

for step, batch in enumerate(train_dataloader):

current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, unet)
@@ -651,6 +659,10 @@ def remove_model(old_ckpt_name):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

            if args.scale_weight_norms:
                keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
                max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
            else:
                keys_scaled, mean_norm, maximum_norm = None, None, None  # keep the logging call below well-defined

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -686,9 +698,12 @@ def remove_model(old_ckpt_name):
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**max_mean_logs)


if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
@@ -787,6 +802,18 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--scale_weight_norms",
type=float,
default=None,
help="Scale the weight of each key pair to help prevent overtraing via exploding gradients. (1 is a good starting point)",
)
parser.add_argument(
"--dropout",
type=float,
default=None,
help="Drops neurons out of training every step (0 is default behavior, 1 would drop all neurons)",
)
parser.add_argument(
"--base_weights",
type=str,
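
Finally, a small sketch of how the two new options are consumed; only the flag names are taken from the PR, while the parser here is reduced and the sample values are arbitrary.

import argparse

# Reduced parser containing just the two options added by this PR.
parser = argparse.ArgumentParser()
parser.add_argument("--scale_weight_norms", type=float, default=None)
parser.add_argument("--dropout", type=float, default=None)

args = parser.parse_args(["--scale_weight_norms", "1.0", "--dropout", "0.1"])

# args.dropout is forwarded to create_network() and stored on each LoRAModule,
# where it gates the dropout applied to the low-rank branch in forward().
# args.scale_weight_norms is the threshold handed to max_norm() after each optimizer step,
# and the returned keys_scaled / mean / max norms are added to the step logs.
print(args.scale_weight_norms, args.dropout)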