Skip to content

Commit

Permalink
Diff Output Preserv loss for SDXL
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 18, 2024
1 parent 2500f5a commit 3cc5b8d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 22 deletions.
17 changes: 7 additions & 10 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,7 @@
from pathlib import Path

# from toolz import curry
from typing import (
List,
Optional,
Sequence,
Tuple,
Union,
)
from typing import Dict, List, Optional, Sequence, Tuple, Union

import toml
import voluptuous
Expand Down Expand Up @@ -78,6 +72,7 @@ class BaseSubsetParams:
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None


@dataclass
Expand Down Expand Up @@ -197,6 +192,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
Expand Down Expand Up @@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask},
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
alpha_mask: {subset.alpha_mask}
custom_attributes: {subset.custom_attributes}
"""
),
" ",
Expand Down
17 changes: 16 additions & 1 deletion library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(
caption_suffix: Optional[str],
token_warmup_min: int,
token_warmup_step: Union[float, int],
custom_attributes: Optional[Dict[str, Any]] = None,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
Expand All @@ -419,6 +420,8 @@ def __init__(
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる

self.custom_attributes = custom_attributes if custom_attributes is not None else {}

self.img_count = 0


Expand Down Expand Up @@ -449,6 +452,7 @@ def __init__(
caption_suffix,
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

Expand All @@ -473,6 +477,7 @@ def __init__(
caption_suffix,
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
)

self.is_reg = is_reg
Expand Down Expand Up @@ -512,6 +517,7 @@ def __init__(
caption_suffix,
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

Expand All @@ -536,6 +542,7 @@ def __init__(
caption_suffix,
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
)

self.metadata_file = metadata_file
Expand Down Expand Up @@ -1474,11 +1481,14 @@ def __getitem__(self, index):
target_sizes_hw = []
flippeds = [] # 変数名が微妙
text_encoder_outputs_list = []
custom_attributes = []

for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]

custom_attributes.append(subset.custom_attributes)

# in case of fine tuning, is_reg is always False
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

Expand Down Expand Up @@ -1646,7 +1656,9 @@ def none_or_stack_elements(tensors_list, converter):
return None
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]

# set example
example = {}
example["custom_attributes"] = custom_attributes # may be list of empty dict
example["loss_weights"] = torch.FloatTensor(loss_weights)
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
Expand Down Expand Up @@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False):
f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
)
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")
logger.info(f"network multiplier: {example['network_multipliers'][j]}")
if "custom_attributes" in example:
logger.info(f"custom attributes: {example['custom_attributes'][j]}")

# if show_input_ids:
# logger.info(f"input ids: {iid}")
Expand Down Expand Up @@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace):
global HIGH_VRAM
HIGH_VRAM = True


def verify_training_args(args: argparse.Namespace):
r"""
Verify training arguments. Also reflect highvram option to global variable
Expand Down
20 changes: 19 additions & 1 deletion sdxl_train_network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from typing import List, Optional

import torch
from accelerate import Accelerator
Expand Down Expand Up @@ -172,7 +173,18 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei

return encoder_hidden_states1, encoder_hidden_states2, pool2

def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
def call_unet(
self,
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_conds,
batch,
weight_dtype,
indices: Optional[List[int]] = None,
):
noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

# get size embeddings
Expand All @@ -186,6 +198,12 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond
vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

if indices is not None and len(indices) > 0:
noisy_latents = noisy_latents[indices]
timesteps = timesteps[indices]
text_embedding = text_embedding[indices]
vector_embedding = vector_embedding[indices]

noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
return noise_pred

Expand Down
35 changes: 25 additions & 10 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)

def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
return noise_pred

Expand Down Expand Up @@ -218,6 +218,30 @@ def get_noise_pred_and_target(
else:
target = noise

# differential output preservation
if "custom_attributes" in batch:
diff_output_pr_indices = []
for i, custom_attributes in enumerate(batch["custom_attributes"]):
if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
diff_output_pr_indices.append(i)

if len(diff_output_pr_indices) > 0:
network.set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
noise_pred_prior = self.call_unet(
args,
accelerator,
unet,
noisy_latents,
timesteps,
text_encoder_conds,
batch,
weight_dtype,
indices=diff_output_pr_indices,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

return noise_pred, target, timesteps, huber_c, None

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
Expand Down Expand Up @@ -1123,15 +1147,6 @@ def remove_model(old_ckpt_name):
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
# # SD only
# encoded_text_encoder_conds = get_weighted_text_embeddings(
# tokenizers[0],
# text_encoder,
# batch["captions"],
# accelerator.device,
# args.max_token_length // 75 if args.max_token_length else 1,
# clip_skip=args.clip_skip,
# )
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
Expand Down

0 comments on commit 3cc5b8d

Please sign in to comment.