Merge pull request bmaltais#475 from kohya-ss/dev
unet config, lion 8bit, ddp, pyramid noise etc.
kohya-ss authored May 3, 2023
2 parents bc803e0 + b271a6b commit 2fad5b8
Showing 14 changed files with 182 additions and 57 deletions.
10 changes: 10 additions & 0 deletions README-ja.md
@@ -115,6 +115,16 @@ accelerate configの質問には以下のように答えてください。(bf1

Training may not work well with other versions. Unless you have another reason, please use the specified versions.

### Optional: Use Lion8bit

To use Lion8bit, you need to upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes`, and on Windows install a Windows build of the wheel, for example from [here](https://github.com/jllllll/bitsandbytes-windows-webui). The steps look like this:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```

When upgrading, update this repository with `pip install .` and upgrade other packages as needed.

## Upgrade

When a new release comes out, you can update with the following command.
32 changes: 32 additions & 0 deletions README.md
@@ -97,6 +97,16 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
Other versions of PyTorch and xformers seem to have problems with training.
Unless you have another reason, please install the specified versions.

### Optional: Use Lion8bit

To use Lion8bit, you need to upgrade `bitsandbytes` to 0.38.0 or later. Uninstall `bitsandbytes`, and on Windows install a Windows build of the wheel from [here](https://github.com/jllllll/bitsandbytes-windows-webui) or another source, for example:

```powershell
pip install https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.38.1-py3-none-any.whl
```

After upgrading `bitsandbytes`, update this repository with `pip install .` and upgrade other required packages manually as needed.
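As a quick sanity check after upgrading, you can verify that the installed `bitsandbytes` actually exposes the 8-bit Lion class; this attribute check mirrors the guard used in `library/train_util.py` below, and the printed messages are illustrative only:

```python
# Check that the installed bitsandbytes provides the 8-bit Lion optimizer.
# bnb.optim.Lion8bit exists only in bitsandbytes 0.38.0 and later.
import bitsandbytes as bnb

try:
    optimizer_class = bnb.optim.Lion8bit
    print("Lion8bit is available:", optimizer_class)
except AttributeError:
    print("bitsandbytes is too old; install 0.38.0 or later")
```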

## Upgrade

When a new release comes out you can upgrade your repo with the following command:
@@ -128,6 +138,28 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History

### 3 May 2023, 2023/05/03

- When saving v2 models in Diffusers format from the training and conversion scripts, it was found that the U-Net configuration differs from that of Hugging Face's stabilityai models (this repository uses `"use_linear_projection": false`, stabilityai uses `true`). The weight shapes differ accordingly, so be careful when using the weight files directly (see the shape sketch below this list). We apologize for the inconvenience.
  - Since the U-Net model is created from the configuration, this should not cause any problems in training or inference.
  - Added the `--unet_use_linear_projection` option to the `convert_diffusers20_original_sd.py` script. With this option you can save, from an SD format model (a single `*.safetensors` or `*.ckpt` file), a Diffusers format model with the same configuration as stabilityai's models. Unfortunately, it is not possible to convert a Diffusers format model in the same way.

- Lion8bit optimizer is supported. [PR #447](https://github.com/kohya-ss/sd-scripts/pull/447) Thanks to sdbds!
  - It is currently optional because you need to update the `bitsandbytes` version. See "Optional: Use Lion8bit" in the installation instructions to use it.
- Multi-GPU training with DDP is supported in each training script. [PR #448](https://github.com/kohya-ss/sd-scripts/pull/448) Thanks to Isotr0py!
- Multi resolution noise (pyramid noise) is supported in each training script. [PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) Thanks to pamparamm!
  - See the PR and this page [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) for details.
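To make the `use_linear_projection` difference concrete, here is a rough, illustrative sketch (the channel count is an arbitrary example): with `true` the transformer blocks project with `nn.Linear`, while with `false` they use a 1x1 `nn.Conv2d`, so the stored weight tensors carry two extra trailing singleton dimensions even though the values are interchangeable via a reshape.

```python
# Illustrative only: the weight-shape difference behind "use_linear_projection".
import torch

ch = 320  # example channel count
linear = torch.nn.Linear(ch, ch)               # use_linear_projection: true
conv = torch.nn.Conv2d(ch, ch, kernel_size=1)  # use_linear_projection: false

print(linear.weight.shape)  # torch.Size([320, 320])
print(conv.weight.shape)    # torch.Size([320, 320, 1, 1])

# Converting between the two layouts is a pure reshape of the same values:
conv_style_weight = linear.weight.data.reshape(ch, ch, 1, 1)
```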

- When training scripts and conversion scripts save v2 models in Diffusers format, the U-Net configuration was found to differ from that of Hugging Face's stabilityai models (this repository uses `"use_linear_projection": false`, stabilityai uses `true`). Because the weight shapes differ, be careful when using the weight files directly. We apologize for the inconvenience.
  - Since the U-Net model is created from the configuration, this should not normally cause problems in training or inference.
  - Added the `--unet_use_linear_projection` option to the `convert_diffusers20_original_sd.py` script. Specifying it saves, from an SD format model (a single `*.safetensors` or `*.ckpt` file), a Diffusers format model whose weight files have the same shapes as stabilityai's models. Note that a Diffusers format model cannot be converted in the same way.

- The Lion8bit optimizer is now supported. [PR #447](https://github.com/kohya-ss/sd-scripts/pull/447) Thanks to sdbds!
  - It is currently optional because the `bitsandbytes` version must be updated. See "[Optional: Use Lion8bit](./README-ja.md#オプションlion8bitを使う)" in the installation instructions to use it.
- Multi-GPU training with DDP is now supported in each training script. [PR #448](https://github.com/kohya-ss/sd-scripts/pull/448) Thanks to Isotr0py!
- Multi resolution noise (pyramid noise) is now supported in each training script. [PR #471](https://github.com/kohya-ss/sd-scripts/pull/471) Thanks to pamparamm!
  - See the PR and this page [Multi-Resolution Noise for Diffusion Model Training](https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2) for details.

### 30 Apr. 2023, 2023/04/30

- Added Chinese translation of [DreamBooth guide](./train_db_README-zh.md) and [LoRA guide](./train_network_README-zh.md). [PR #459](https://github.com/kohya-ss/sd-scripts/pull/459) Thanks to tomj2ee!
9 changes: 7 additions & 2 deletions fine_tune.py
@@ -21,7 +21,7 @@
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
+from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings, pyramid_noise_like


def train(args):
@@ -90,7 +90,7 @@ def train(args):
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the model
-    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype)
+    text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)

    # verify load/save model formats
    if load_stable_diffusion_format:
@@ -228,6 +228,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

    # transform DDP after prepare
    text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)

    # Experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)
@@ -304,6 +307,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
    if args.noise_offset:
        # https://www.crosslabs.org//blog/diffusion-with-offset-noise
        noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
    elif args.multires_noise_iterations:
        noise = pyramid_noise_like(noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount)

    # Sample a random timestep for each image
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
13 changes: 13 additions & 0 deletions library/custom_train_functions.py
@@ -1,5 +1,6 @@
import torch
import argparse
import random
import re
from typing import List, Optional, Union

@@ -342,3 +343,15 @@ def get_weighted_text_embeddings(
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

return text_embeddings


# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
    b, c, w, h = noise.shape
    u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
    for i in range(iterations):
        r = random.random() * 2 + 2  # rather than always scaling by 2x, pick a random factor in [2, 4)
        w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
        if w == 1 or h == 1:
            break  # lowest resolution is 1x1
    return noise / noise.std()  # scaled back to roughly unit variance
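A usage sketch (tensor shapes are arbitrary example values): the training scripts substitute pyramid noise for the plain Gaussian noise when `--multires_noise_iterations` is set, as in the `fine_tune.py` hunk above.

```python
# Hypothetical call mirroring how fine_tune.py uses pyramid_noise_like.
import torch

latents = torch.randn(4, 4, 64, 64)  # example batch of VAE latents
noise = torch.randn_like(latents)
noise = pyramid_noise_like(noise, latents.device, iterations=6, discount=0.3)
print(noise.std())  # roughly 1.0 after the final rescaling
```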
3 changes: 1 addition & 2 deletions library/huggingface_util.py
@@ -1,9 +1,8 @@
-from typing import *
+from typing import Union, BinaryIO
from huggingface_hub import HfApi
from pathlib import Path
import argparse
import os

from library.utils import fire_in_thread


16 changes: 11 additions & 5 deletions library/model_util.py
@@ -22,6 +22,7 @@
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
# UNET_PARAMS_USE_LINEAR_PROJECTION = False

VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
@@ -34,6 +35,7 @@
# V2
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True

# reference models used to load Diffusers settings
DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5"
@@ -357,8 +359,9 @@ def convert_ldm_unet_checkpoint(v2, checkpoint, config):

new_checkpoint[new_path] = unet_state_dict[old_path]

-    # In SD v2, the 1x1 conv2d has been replaced by linear, so convert linear -> conv
-    if v2:
+    # In SD v2, the 1x1 conv2d has been replaced by linear
+    # The Diffusers side was mistakenly left as conv2d, so conversion is needed
+    if v2 and not config.get('use_linear_projection', False):
        linear_transformer_to_conv(new_checkpoint)

return new_checkpoint
@@ -468,7 +471,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint


-def create_unet_diffusers_config(v2):
+def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
@@ -500,7 +503,10 @@ def create_unet_diffusers_config(v2):
        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
        attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
        # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
    )
    if v2 and use_linear_projection_in_v2:
        config["use_linear_projection"] = True

    return config

@@ -846,11 +852,11 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):


# TODO: the behavior of the dtype argument looks unreliable and should be checked; whether text_encoder can be created in the specified dtype is unverified
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None):
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device="cpu", dtype=None, unet_use_linear_projection_in_v2=False):
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)

    # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(v2)
+    unet_config = create_unet_diffusers_config(v2, unet_use_linear_projection_in_v2)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)

    unet = UNet2DConditionModel(**unet_config).to(device)
66 changes: 63 additions & 3 deletions library/train_util.py
@@ -19,6 +19,7 @@
Union,
)
from accelerate import Accelerator
import gc
import glob
import math
import os
@@ -30,6 +31,7 @@

from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from transformers import CLIPTokenizer
@@ -1883,7 +1885,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, Lion8bit,SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor",
)

# backward compatibility
@@ -2119,6 +2121,18 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
        default=None,
        help="enable noise offset with this value (if enabled, around 0.1 is recommended) / Noise offsetを有効にしてこの値を設定する(有効にする場合は0.1程度を推奨)",
    )
    parser.add_argument(
        "--multires_noise_iterations",
        type=int,
        default=None,
        help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
    )
    parser.add_argument(
        "--multires_noise_discount",
        type=float,
        default=0.3,
        help="set discount value for multires noise (has no effect without --multires_noise_iterations) / Multires noiseのdiscount値を設定する(--multires_noise_iterations指定時のみ有効)",
    )
    parser.add_argument(
        "--lowram",
        action="store_true",
@@ -2448,7 +2462,7 @@ def task():


def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, Lion8bit, SGDNesterov, SGDNesterov8bit, DAdaptation, Adafactor"

optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -2526,6 +2540,22 @@ def get_optimizer(args, trainable_params):
optimizer_class = lion_pytorch.Lion
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "Lion8bit".lower():
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")

print(f"use 8-bit Lion optimizer | {optimizer_kwargs}")
try:
optimizer_class = bnb.optim.Lion8bit
except AttributeError:
raise AttributeError(
"No Lion8bit. The version of bitsandbytes installed seems to be old. Please install 0.38.0 or later. / Lion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.38.0以上をインストールしてください"
)

optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type == "SGDNesterov".lower():
print(f"use SGD with Nesterov optimizer | {optimizer_kwargs}")
if "momentum" not in optimizer_kwargs:
@@ -2850,7 +2880,7 @@ def prepare_dtype(args: argparse.Namespace):
return weight_dtype, save_dtype


-def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
+def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
    name_or_path = args.pretrained_model_name_or_path
    name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
    load_stable_diffusion_format = os.path.isfile(name_or_path)  # determine SD or Diffusers
@@ -2879,6 +2909,36 @@ def load_target_model(args: argparse.Namespace, weight_dtype, device="cpu"):
return text_encoder, vae, unet, load_stable_diffusion_format


def transform_if_model_is_DDP(text_encoder, unet, network=None):
    # unwrap text_encoder, unet and network from DistributedDataParallel if they were wrapped
    return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)


def load_target_model(args, weight_dtype, accelerator):
    # load models one process at a time; the other processes wait at the barrier
    for pi in range(accelerator.state.num_processes):
        if pi == accelerator.state.local_process_index:
            print(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

            text_encoder, vae, unet, load_stable_diffusion_format = _load_target_model(
                args, weight_dtype, accelerator.device if args.lowram else "cpu"
            )

            # work on low-ram device
            if args.lowram:
                text_encoder.to(accelerator.device)
                unet.to(accelerator.device)
                vae.to(accelerator.device)

            gc.collect()
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)

    return text_encoder, vae, unet, load_stable_diffusion_format


def patch_accelerator_for_fp16_training(accelerator):
org_unscale_grads = accelerator.scaler._unscale_grads_

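For illustration, here is how the unwrap helper behaves with plain (non-DDP) modules; toy `nn.Linear` stand-ins replace the real models, and a real DDP case would additionally require an initialized process group:

```python
# Toy illustration of transform_if_model_is_DDP: DDP-wrapped models are
# unwrapped to their .module, plain modules pass through unchanged.
import torch
from library.train_util import transform_if_model_is_DDP

text_encoder = torch.nn.Linear(8, 8)  # stand-in for the real text encoder
unet = torch.nn.Linear(8, 8)          # stand-in for the real U-Net

text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
print(type(text_encoder), type(unet))  # both remain plain nn.Linear here
```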
2 changes: 1 addition & 1 deletion networks/lora_interrogator.py
@@ -23,7 +23,7 @@ def interrogate(args):
print(f"loading SD model: {args.sd_model}")
args.pretrained_model_name_or_path = args.sd_model
args.vae = None
text_encoder, vae, unet, _ = train_util.load_target_model(args,weights_dtype, DEVICE)
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)

print(f"loading LoRA: {args.model}")
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)