* 2023/02/16 (v20.7.3)
    - Noise offset is now recorded in the metadata. Thanks to space-nuko!
    - Show a moving average of the loss in `train_network.py` and `train_db.py` so that the reported loss does not jump between steps. Thanks to shirayu!
bmaltais committed Feb 18, 2023
1 parent f9863e3 commit 674ed88
Showing 3 changed files with 63 additions and 62 deletions.
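
For context, the "moving average loss" below is a sliding window holding one loss value per dataloader step: the window fills during the first epoch, and on later epochs each step's old value is evicted and replaced. A minimal runnable sketch of the bookkeeping both scripts now share; `train_one_step` and the dummy loop bounds are illustrative placeholders, the other names match the diff:

```python
import random

# Sketch of the moving-average bookkeeping added to train_db.py and
# train_network.py; train_one_step stands in for the real training step.
def train_one_step(step):
    return random.random()  # dummy per-step loss for illustration

num_train_epochs, steps_per_epoch = 3, 5
loss_list = []    # one entry per dataloader step, filled during the first epoch
loss_total = 0.0  # running sum of the values currently in loss_list

for epoch in range(num_train_epochs):
    for step in range(steps_per_epoch):
        current_loss = train_one_step(step)
        if epoch == 0:
            loss_list.append(current_loss)   # the window is still filling up
        else:
            loss_total -= loss_list[step]    # evict last epoch's value for this step
            loss_list[step] = current_loss
        loss_total += current_loss
        avr_loss = loss_total / len(loss_list)  # smoothed value for the progress bar
        print(f"epoch {epoch} step {step} loss {current_loss:.3f} avg {avr_loss:.3f}")
```

Beyond the two changelog items, the diff also removes the hard-coded D-Adaptation optimizer and LambdaLR scheduler experiment from both scripts, restoring the configured `optimizer_class` and lr scheduler.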
3 changes: 3 additions & 0 deletions README.md
@@ -143,6 +143,9 @@ Then redo the installation instruction within the kohya_ss venv.

## Change history

* 2023/02/16 (v20.7.3)
- Noise offset is now recorded in the metadata. Thanks to space-nuko!
- Show a moving average of the loss in `train_network.py` and `train_db.py` so that the reported loss does not jump between steps. Thanks to shirayu!
* 2023/02/11 (v20.7.2):
- `lora_interrogator.py` is added to the `networks` folder. See `python networks\lora_interrogator.py -h` for usage.
- For LoRAs whose activation word is unknown, this script compares the Text Encoder output with the LoRA applied against the output without it, to find which tokens the LoRA affects. Hopefully this lets you figure out the activation word. LoRAs trained with captions do not seem to be interrogatable.
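
The noise offset itself (applied in both training scripts below, per the linked crosslabs post) adds one random scalar per sample and channel on top of the usual per-pixel noise, so the model also learns to shift an image's overall brightness. A minimal sketch, assuming `latents` is the usual NCHW tensor; the function name and the value 0.1 are illustrative, not from the diff:

```python
import torch

def sample_offset_noise(latents: torch.Tensor, noise_offset: float) -> torch.Tensor:
    # per-pixel Gaussian noise, as in standard diffusion training
    noise = torch.randn_like(latents)
    if noise_offset:
        # one extra scalar per (sample, channel), broadcast over height and width
        noise += noise_offset * torch.randn(
            (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
        )
    return noise

noise = sample_offset_noise(torch.zeros(4, 4, 64, 64), noise_offset=0.1)
```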
40 changes: 18 additions & 22 deletions train_db.py
@@ -17,8 +17,6 @@
import library.train_util as train_util
from library.train_util import DreamBoothDataset

import torch.optim as optim
import dadaptation

def collate_fn(examples):
return examples[0]
@@ -135,16 +133,13 @@ def train(args):
trainable_params = unet.parameters()

# beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
# optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

# prepare the dataloader
# number of DataLoader processes: 0 means the main process
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1, but at most the specified number
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers)
print('enable dadatation.')
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0, d0=0.00000001)


# calculate the number of training steps
if args.max_train_epochs is not None:
@@ -155,14 +150,8 @@ def train(args):
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end

# prepare the lr scheduler
# lr_scheduler = diffusers.optimization.get_scheduler(
# args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)

# For Adam
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
lr_lambda=[lambda epoch: 1],
last_epoch=-1,
verbose=False)
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)

# experimental feature: perform fp16 training including gradients (cast the entire model to fp16)
if args.full_fp16:
@@ -217,6 +206,8 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth")

loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
@@ -227,7 +218,6 @@ def train(args):
if args.gradient_checkpointing or global_step < args.stop_text_encoder_training:
text_encoder.train()

loss_total = 0
for step, batch in enumerate(train_dataloader):
# stop training the Text Encoder at the specified number of steps
if global_step == args.stop_text_encoder_training:
@@ -244,10 +234,13 @@ def train(args):
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
b_size = latents.shape[0]
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
@@ -299,21 +292,24 @@ def train(args):

current_loss = loss.detach().item()
if args.logging_dir is not None:
# logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
logs = {"loss": current_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']}
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)

if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
# logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
logs = {"avg_loss": avr_loss, "dlr": optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']} # , "lr": lr_scheduler.get_last_lr()[0]}
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)

accelerator.wait_for_everyone()
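
The metadata half of the change lands in `train_network.py` below: `ss_noise_offset` (together with `ss_sd_scripts_commit_hash`) joins the `ss_*` keys written into the saved LoRA. As a sketch of how such keys can be inspected afterwards, assuming a `.safetensors` output and the `safetensors` library; the file name is illustrative:

```python
from safetensors import safe_open

# Inspect the training metadata stored in a LoRA file.
with safe_open("my_lora.safetensors", framework="pt") as f:  # illustrative path
    metadata = f.metadata() or {}

print(metadata.get("ss_noise_offset"))            # recorded by this commit
print(metadata.get("ss_sd_scripts_commit_hash"))  # also new in this commit
```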
82 changes: 42 additions & 40 deletions train_network.py
@@ -1,5 +1,7 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Optional, Union
import importlib
import argparse
@@ -19,8 +21,6 @@
import library.train_util as train_util
from library.train_util import DreamBoothDataset, FineTuningDataset

import torch.optim as optim
import dadaptation

def collate_fn(examples):
return examples[0]
@@ -156,7 +156,9 @@ def train(args):

# load the models
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

# unnecessary, but works on low-RAM devices
text_encoder.to("cuda")
unet.to("cuda")
# incorporate xformers or memory efficient attention into the model
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

@@ -214,15 +216,10 @@ def train(args):
else:
optimizer_class = torch.optim.AdamW

# trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(None, None)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

# beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so the options are omitted for now
# optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
print('enable dadatation.')
optimizer = dadaptation.DAdaptAdam(trainable_params, lr=1.0, decouple=True, weight_decay=0)
# optimizer = dadaptation.DAdaptSGD(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
# optimizer = dadaptation.DAdaptAdaGrad(trainable_params, lr=1.0, weight_decay=0, d0=1e-6)
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

# prepare the dataloader
# number of DataLoader processes: 0 means the main process
@@ -237,23 +234,10 @@

# prepare the lr scheduler
# lr_scheduler = diffusers.optimization.get_scheduler(
# lr_scheduler = get_scheduler_fix(
# args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
# num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
# num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# override lr_scheduler.

# For Adam
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
lr_lambda=[lambda epoch: 0.25, lambda epoch: 1],
last_epoch=-1,
verbose=False)

# For SGD optim
# lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer,
# lr_lambda=[lambda epoch: 1, lambda epoch: 0.5],
# last_epoch=-1,
# verbose=False)
lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

# experimental feature: perform fp16 training including gradients (cast the entire model to fp16)
if args.full_fp16:
@@ -278,17 +262,26 @@ def train(args):
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device)
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train()
text_encoder.train()

# set top parameter requires_grad = True so that gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
if type(text_encoder) == DDP:
text_encoder.module.text_model.embeddings.requires_grad_(True)
else:
text_encoder.text_model.embeddings.requires_grad_(True)
else:
unet.eval()
text_encoder.eval()

# support DistributedDataParallel
if type(text_encoder) == DDP:
text_encoder = text_encoder.module
unet = unet.module
network = network.module

network.prepare_grad_etc(text_encoder, unet)

if not cache_latents:
@@ -360,11 +353,13 @@ def train(args):
"ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_noise_offset": args.noise_offset,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training
"ss_training_comment": args.training_comment, # will not be updated after training
"ss_sd_scripts_commit_hash": train_util.get_git_revision_hash()
}

# uncomment if another network is added
@@ -398,6 +393,8 @@ def train(args):
if accelerator.is_main_process:
accelerator.init_trackers("network_train")

loss_list = []
loss_total = 0.0
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
@@ -406,7 +403,6 @@ def train(args):

network.on_epoch_start(text_encoder, unet)

loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network):
with torch.no_grad():
@@ -425,6 +421,9 @@

# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -435,7 +434,8 @@ def train(args):
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
with autocast():
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

if args.v_parameterization:
# v-parameterization training
@@ -466,23 +466,25 @@ def train(args):
global_step += 1

current_loss = loss.detach().item()
if epoch == 0:
loss_list.append(current_loss)
else:
loss_total -= loss_list[step]
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / (step+1)
# logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
# progress_bar.set_postfix(**logs)
logs_str = f"loss: {avr_loss:.3f}, dlr0: {optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']:.2e}, dlr1: {optimizer.param_groups[1]['d']*optimizer.param_groups[1]['lr']:.2e}"
progress_bar.set_postfix_str(logs_str)
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
logs['lr/d*lr'] = optimizer.param_groups[0]['d']*optimizer.param_groups[0]['lr']
accelerator.log(logs, step=global_step)

if global_step >= args.max_train_steps:
break

if args.logging_dir is not None:
logs = {"loss/epoch": loss_total / len(train_dataloader)}
logs = {"loss/epoch": loss_total / len(loss_list)}
accelerator.log(logs, step=epoch+1)

accelerator.wait_for_everyone()
@@ -568,4 +570,4 @@ def remove_old_func(old_epoch_no):
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")

args = parser.parse_args()
train(args)
train(args)
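
One more detail in the `train_network.py` diff above: under multi-GPU training, accelerate hands back models wrapped in `DistributedDataParallel`, so attribute access such as `text_encoder.text_model.embeddings` has to go through `.module`. A minimal sketch of the unwrap pattern the commit adopts:

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model: nn.Module) -> nn.Module:
    # DDP hides the wrapped network behind .module; plain modules pass through
    return model.module if isinstance(model, DDP) else model
```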
