Add support for EDM2 timestep weighting network
cheald committed Apr 7, 2024
1 parent 1cf0133 commit 0afee96
Showing 3 changed files with 139 additions and 0 deletions.
110 changes: 110 additions & 0 deletions library/timestep_uncertainty.py
@@ -0,0 +1,110 @@
# adapted from https://github.com/NVlabs/edm2/blob/3a6682d3d25395df64863d3cea563bf3f3380769/training/networks_edm2.py

import torch
import numpy as np
import os
from safetensors.torch import load_file

#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)


class MPFourier(torch.nn.Module):
    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)

class MPConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w)) # forced weight normalization
        w = normalize(w) # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))

class TimestepUncertaintyLossNetwork(torch.nn.Module):
    def __init__(self,
        logvar_channels = 128, # Intermediate dimensionality for uncertainty estimation.
    ):
        super().__init__()
        self.logvar_fourier = MPFourier(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])

    def forward(self, sigma):
        c_noise = sigma.reshape(-1, 1, 1, 1).flatten().log() / 4
        logvar = self.logvar_linear(self.logvar_fourier(c_noise)).reshape(-1, 1, 1, 1)
        return logvar

    def loss(self, sigma, loss):
        logvar = self.forward(sigma)
        return loss / logvar.exp() + logvar

    def load_weights(self, file, dtype=None):
        if not os.path.exists(file):
            print(f"WARNING: Could not load weights from '{file}' because the file does not exist.")
            return

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file
            state_dict = load_file(file)
        else:
            state_dict = torch.load(file)

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to(dtype)
                state_dict[key] = v

        self.load_state_dict(state_dict)

    def save_weights(self, file, dtype=torch.float32, metadata={}):
        metadata = {}

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)
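
For a quick sanity check, the new module can be exercised on its own. A minimal usage sketch (not part of the commit), assuming the file above is importable as library.timestep_uncertainty and that the unreduced per-sample loss broadcasts against a (B, 1, 1, 1) log-variance:

import torch
from library.timestep_uncertainty import TimestepUncertaintyLossNetwork

net = TimestepUncertaintyLossNetwork()

sigma = torch.tensor([0.1, 1.0, 10.0])      # per-sample noise levels
per_sample_loss = torch.rand(3, 4, 64, 64)  # stand-in for an unreduced MSE/Huber loss on latents

logvar = net(sigma)                          # learned log-variance, shape (3, 1, 1, 1)
weighted = net.loss(sigma, per_sample_loss)  # loss / exp(logvar) + logvar, as in EDM2
print(logvar.flatten(), weighted.mean())
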
10 changes: 10 additions & 0 deletions library/train_util.py
@@ -3374,6 +3374,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        default=None,
        help="If set, dynamically learn the value for `multires_noise_discount`. 7e-2..5e-2 is a good starting point",
    )
    parser.add_argument(
        "--sigma_uncertainty_model",
        type=str,
        default=None,
        help="Path to EDM2 timestep uncertainty model weights; loaded before training and used to weight the loss (saved back here when --train_sigma_uncertainty is set)",
    )
    parser.add_argument(
        "--train_sigma_uncertainty",
        action="store_true",
        help="Train the sigma (timestep) uncertainty network alongside the main model",
    )

    if support_dreambooth:
        # DreamBooth training
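The two new flags are consumed together in train_network.py below: passing only --sigma_uncertainty_model loads the weights (if the file exists) and uses the network frozen for loss weighting, while adding --train_sigma_uncertainty also optimizes the network and saves it back to the same path each epoch. A small stand-in sketch of the flag handling, using a hypothetical weights path (the real parser is built by add_training_arguments and carries many more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--sigma_uncertainty_model", type=str, default=None)
parser.add_argument("--train_sigma_uncertainty", action="store_true")

# "sigma_uncertainty.safetensors" is a hypothetical path chosen for illustration
args = parser.parse_args(["--sigma_uncertainty_model", "sigma_uncertainty.safetensors", "--train_sigma_uncertainty"])
print(args.sigma_uncertainty_model, args.train_sigma_uncertainty)
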
19 changes: 19 additions & 0 deletions train_network.py
@@ -14,6 +14,7 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.timestep_uncertainty import TimestepUncertaintyLossNetwork

init_ipex()

@@ -368,6 +369,16 @@ def train(self, args):
                network.get_parameter("noise_discount"),
            ], "lr": args.multires_discount_lr * args.gradient_accumulation_steps})

        if args.sigma_uncertainty_model:
            timestep_uncertainty_loss = TimestepUncertaintyLossNetwork().to(accelerator.device)
            timestep_uncertainty_loss.load_weights(args.sigma_uncertainty_model)
            if args.train_sigma_uncertainty:
                timestep_uncertainty_loss.train()
                # important that you don't have weight decay here, this model has forced weight norm and its weights will not grow in magnitude - drhead
                trainable_params.append({"params": timestep_uncertainty_loss.parameters(), "lr": 1e-3, "weight_decay": 0.0})
            else:
                timestep_uncertainty_loss.eval()

        optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

        # prepare the dataloader
@@ -812,6 +823,8 @@ def remove_model(old_ckpt_name):
        # For --sample_at_first
        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)

        sigmas = ((1 - noise_scheduler.alphas_cumprod) / noise_scheduler.alphas_cumprod).sqrt().to(accelerator.device)

        # training loop
        for epoch in range(num_train_epochs):
            accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -915,6 +928,9 @@ def remove_model(old_ckpt_name):
                    if args.debiased_estimation_loss:
                        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)

                    if args.sigma_uncertainty_model:
                        loss = timestep_uncertainty_loss.loss(sigmas[timesteps], loss)

                    loss = loss.mean()  # this is already an average, so no need to divide by batch_size
                    accelerator.backward(loss)

@@ -988,6 +1004,9 @@ def remove_model(old_ckpt_name):
                    ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
                    save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)

                    if args.train_sigma_uncertainty:
                        accelerator.unwrap_model(timestep_uncertainty_loss).save_weights(args.sigma_uncertainty_model)

                    remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
                    if remove_epoch_no is not None:
                        remove_ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, remove_epoch_no)
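For context, the sigmas table built just before the training loop follows the usual relation sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), so sigmas[timesteps] is the per-sample noise level fed to the uncertainty network. A minimal sketch, assuming noise_scheduler is a diffusers DDPMScheduler as in the existing training script:

import torch
from diffusers import DDPMScheduler

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
alphas_cumprod = noise_scheduler.alphas_cumprod          # shape (1000,)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod).sqrt()  # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)

timesteps = torch.randint(0, 1000, (4,))  # per-sample timesteps from a batch
print(sigmas[timesteps])                  # noise levels passed to TimestepUncertaintyLossNetwork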
