Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add masked loss #1207

Merged
merged 11 commits into from
Mar 26, 2024
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser

## Change History

### Masked loss

`train_network.py`, `sdxl_train_network.py` and `sdxl_train.py` now support the masked loss. `--masked_loss` option is added.

NOTE: `train_network.py` and `sdxl_train.py` are not tested yet.

ControlNet dataset is used to specify the mask. The mask images should be the RGB images. The pixel value 255 in R channel is treated as the mask (the loss is calculated only for the pixels with the mask), and 0 is treated as the non-mask. See details for the dataset specification in the [LLLite documentation](./docs/train_lllite_README.md#preparing-the-dataset).


### Working in progress

- Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging.
Expand Down
8 changes: 6 additions & 2 deletions docs/train_lllite_README-ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ ComfyUIのカスタムノードを用意しています。: https://github.com/k
## モデルの学習

### データセットの準備
通常のdatasetに加え、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です
DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにconditioning imageを格納してください。

たとえば DreamBooth 方式でキャプションファイルを用いる場合の設定ファイルは以下のようになります。
(finetuning 方式の dataset はサポートしていません。)

conditioning imageは学習用画像と同じbasenameを持つ必要があります。また、conditioning imageは学習用画像と同じサイズに自動的にリサイズされます。conditioning imageにはキャプションファイルは不要です。

たとえば、キャプションにフォルダ名ではなくキャプションファイルを用いる場合の設定ファイルは以下のようになります。

```toml
[[datasets.subsets]]
Expand Down
4 changes: 3 additions & 1 deletion docs/train_lllite_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ Due to the limitations of the inference environment, only CrossAttention (attn1

### Preparing the dataset

In addition to the normal dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.
In addition to the normal DreamBooth method dataset, please store the conditioning image in the directory specified by `conditioning_data_dir`. The conditioning image must have the same basename as the training image. The conditioning image will be automatically resized to the same size as the training image. The conditioning image does not require a caption file.

(We do not support the finetuning method dataset.)

```toml
[[datasets.subsets]]
Expand Down
12 changes: 8 additions & 4 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,10 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
}

def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
assert (
support_dreambooth or support_finetuning or support_controlnet
), "Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more. / DreamBooth モードか fine tuning モードのどちらも指定されていません。1つ以上指定してください。"
assert support_dreambooth or support_finetuning or support_controlnet, (
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
)

self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
Expand Down Expand Up @@ -322,7 +323,10 @@ def validate_flex_dataset(dataset_config: dict):

self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
if support_controlnet:
self.dataset_schema = self.cn_dataset_schema
else:
self.dataset_schema = self.db_dataset_schema
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
Expand Down
24 changes: 20 additions & 4 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import random
import re
from typing import List, Optional, Union
from .utils import setup_logging
from .utils import setup_logging

setup_logging()
import logging
import logging

logger = logging.getLogger(__name__)


def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"):
return
Expand Down Expand Up @@ -64,7 +67,7 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr+1).float().to(loss.device)
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
else:
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
loss = loss * snr_weight
Expand Down Expand Up @@ -92,13 +95,15 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los
loss = loss + loss / scale * v_pred_like_loss
return loss


def apply_debiased_estimation(loss, timesteps, noise_scheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
weight = 1/torch.sqrt(snr_t)
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss


# TODO train_utilと分散しているのでどちらかに寄せる


Expand Down Expand Up @@ -474,6 +479,17 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def apply_masked_loss(loss, batch):
# mask image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel

# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss


"""
##########################################
# Perlin Noise
Expand Down
45 changes: 35 additions & 10 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,6 +1835,9 @@ def __init__(

db_subsets = []
for subset in subsets:
assert (
not subset.random_crop
), "random_crop is not supported in ControlNetDataset / random_cropはControlNetDatasetではサポートされていません"
db_subset = DreamBoothSubset(
subset.image_dir,
False,
Expand Down Expand Up @@ -1885,7 +1888,7 @@ def __init__(

# assert all conditioning data exists
missing_imgs = []
cond_imgs_with_img = set()
cond_imgs_with_pair = set()
for image_key, info in self.dreambooth_dataset_delegate.image_data.items():
db_subset = self.dreambooth_dataset_delegate.image_to_subset[image_key]
subset = None
Expand All @@ -1899,23 +1902,29 @@ def __init__(
logger.warning(f"not directory: {subset.conditioning_data_dir}")
continue

img_basename = os.path.basename(info.absolute_path)
ctrl_img_path = os.path.join(subset.conditioning_data_dir, img_basename)
if not os.path.exists(ctrl_img_path):
img_basename = os.path.splitext(os.path.basename(info.absolute_path))[0]
ctrl_img_path = glob_images(subset.conditioning_data_dir, img_basename)
if len(ctrl_img_path) < 1:
missing_imgs.append(img_basename)
continue
ctrl_img_path = ctrl_img_path[0]
ctrl_img_path = os.path.abspath(ctrl_img_path) # normalize path

info.cond_img_path = ctrl_img_path
cond_imgs_with_img.add(ctrl_img_path)
cond_imgs_with_pair.add(os.path.splitext(ctrl_img_path)[0]) # remove extension because Windows is case insensitive

extra_imgs = []
for subset in subsets:
conditioning_img_paths = glob_images(subset.conditioning_data_dir, "*")
extra_imgs.extend(
[cond_img_path for cond_img_path in conditioning_img_paths if cond_img_path not in cond_imgs_with_img]
)
conditioning_img_paths = [os.path.abspath(p) for p in conditioning_img_paths] # normalize path
extra_imgs.extend([p for p in conditioning_img_paths if os.path.splitext(p)[0] not in cond_imgs_with_pair])

assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
assert (
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
assert (
len(extra_imgs) == 0
), f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"

self.conditioning_image_transforms = IMAGE_TRANSFORMS

Expand Down Expand Up @@ -3049,6 +3058,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
) # TODO move to SDXL training, because it is not supported by SD1/2
parser.add_argument("--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う")

parser.add_argument(
"--ddp_timeout",
type=int,
Expand Down Expand Up @@ -3111,6 +3121,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
)

parser.add_argument(
"--noise_offset",
type=float,
Expand Down Expand Up @@ -3284,6 +3295,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)


def add_masked_loss_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--conditioning_data_dir",
type=str,
default=None,
help="conditioning data directory / 条件付けデータのディレクトリ",
)
parser.add_argument(
"--masked_loss",
action="store_true",
help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要",
)


def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
Expand Down
9 changes: 7 additions & 2 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch
from library.device_utils import init_ipex, clean_memory_on_device


init_ipex()

from accelerate.utils import set_seed
Expand Down Expand Up @@ -40,6 +41,7 @@
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel

Expand Down Expand Up @@ -126,7 +128,7 @@ def train(args):

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
Expand Down Expand Up @@ -595,9 +597,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
or args.scale_v_pred_loss_like_noise_pred
or args.v_pred_like_loss
or args.debiased_estimation_loss
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
Expand Down Expand Up @@ -763,6 +768,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
Expand Down Expand Up @@ -799,7 +805,6 @@ def setup_parser() -> argparse.ArgumentParser:
help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
+ f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
)

return parser


Expand Down
7 changes: 6 additions & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device


init_ipex()

from accelerate.utils import set_seed
Expand All @@ -34,6 +35,7 @@
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments

Expand All @@ -60,7 +62,7 @@ def train(args):

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, args.masked_loss, True))
if args.dataset_config is not None:
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
Expand Down Expand Up @@ -357,6 +359,8 @@ def train(args):
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down Expand Up @@ -482,6 +486,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
Expand Down
7 changes: 6 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from library.device_utils import init_ipex, clean_memory_on_device


init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
Expand All @@ -40,6 +41,7 @@
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
)
from library.utils import setup_logging, add_logging_arguments

Expand Down Expand Up @@ -159,7 +161,7 @@ def train(self, args):

# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
if use_user_config:
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
Expand Down Expand Up @@ -852,6 +854,8 @@ def remove_model(old_ckpt_name):
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down Expand Up @@ -975,6 +979,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
Expand Down
Loading