From 3cc5b8db99c66b9e205c4fd4a5f969090c51ef58 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 18 Oct 2024 20:57:13 +0900 Subject: [PATCH 1/3] Diff Output Preserv loss for SDXL --- library/config_util.py | 17 +++++++---------- library/train_util.py | 17 ++++++++++++++++- sdxl_train_network.py | 20 +++++++++++++++++++- train_network.py | 35 +++++++++++++++++++++++++---------- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index f8cdfe60a..fc1fbf46d 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -10,13 +10,7 @@ from pathlib import Path # from toolz import curry -from typing import ( - List, - Optional, - Sequence, - Tuple, - Union, -) +from typing import Dict, List, Optional, Sequence, Tuple, Union import toml import voluptuous @@ -78,6 +72,7 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + custom_attributes: Optional[Dict[str, Any]] = None @dataclass @@ -197,6 +192,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "token_warmup_step": Any(float, int), "caption_prefix": str, "caption_suffix": str, + "custom_attributes": dict, } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} - token_warmup_min: {subset.token_warmup_min}, - token_warmup_step: {subset.token_warmup_step}, - alpha_mask: {subset.alpha_mask}, + token_warmup_min: {subset.token_warmup_min} + token_warmup_step: {subset.token_warmup_step} + alpha_mask: {subset.alpha_mask} + custom_attributes: {subset.custom_attributes} """ ), " ", diff --git a/library/train_util.py b/library/train_util.py index 4a446e81c..7d3fce5b2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -396,6 +396,7 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: self.image_dir = image_dir self.alpha_mask = alpha_mask if alpha_mask is not None else False @@ -419,6 +420,8 @@ def __init__( self.token_warmup_min = token_warmup_min # step=0におけるタグの数 self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.custom_attributes = custom_attributes if custom_attributes is not None else {} + self.img_count = 0 @@ -449,6 +452,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -473,6 +477,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.is_reg = is_reg @@ -512,6 +517,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes: Optional[Dict[str, Any]] = None, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -536,6 +542,7 @@ def __init__( caption_suffix, token_warmup_min, token_warmup_step, + custom_attributes=custom_attributes, ) self.metadata_file = metadata_file @@ -1474,11 +1481,14 @@ def __getitem__(self, index): target_sizes_hw = [] flippeds = [] # 変数名が微妙 text_encoder_outputs_list = [] + custom_attributes = [] for image_key in bucket[image_index : image_index + bucket_batch_size]: 
image_info = self.image_data[image_key] subset = self.image_to_subset[image_key] + custom_attributes.append(subset.custom_attributes) + # in case of fine tuning, is_reg is always False loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) @@ -1646,7 +1656,9 @@ def none_or_stack_elements(tensors_list, converter): return None return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))] + # set example example = {} + example["custom_attributes"] = custom_attributes # may be list of empty dict example["loss_weights"] = torch.FloatTensor(loss_weights) example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor) example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x) @@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False): f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}' ) if "network_multipliers" in example: - print(f"network multiplier: {example['network_multipliers'][j]}") + logger.info(f"network multiplier: {example['network_multipliers'][j]}") + if "custom_attributes" in example: + logger.info(f"custom attributes: {example['custom_attributes'][j]}") # if show_input_ids: # logger.info(f"input ids: {iid}") @@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace): global HIGH_VRAM HIGH_VRAM = True + def verify_training_args(args: argparse.Namespace): r""" Verify training arguments. Also reflect highvram option to global variable diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 4a16a4891..d45df6e05 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,4 +1,5 @@ import argparse +from typing import List, Optional import torch from accelerate import Accelerator @@ -172,7 +173,18 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei return encoder_hidden_states1, encoder_hidden_states2, pool2 - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet( + self, + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_conds, + batch, + weight_dtype, + indices: Optional[List[int]] = None, + ): noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype # get size embeddings @@ -186,6 +198,12 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype) text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype) + if indices is not None and len(indices) > 0: + noisy_latents = noisy_latents[indices] + timesteps = timesteps[indices] + text_embedding = text_embedding[indices] + vector_embedding = vector_embedding[indices] + noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) return noise_pred diff --git a/train_network.py b/train_network.py index d5330aef4..ef766737d 100644 --- a/train_network.py +++ b/train_network.py @@ -143,7 +143,7 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex for t_enc in text_encoders: t_enc.to(accelerator.device, dtype=weight_dtype) - def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype): + def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, 
weight_dtype, **kwargs): noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample return noise_pred @@ -218,6 +218,30 @@ def get_noise_pred_and_target( else: target = noise + # differential output preservation + if "custom_attributes" in batch: + diff_output_pr_indices = [] + for i, custom_attributes in enumerate(batch["custom_attributes"]): + if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + return noise_pred, target, timesteps, huber_c, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): @@ -1123,15 +1147,6 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # # SD only - # encoded_text_encoder_conds = get_weighted_text_embeddings( - # tokenizers[0], - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy, From ef70aa7b42b5c923cc1a8594b2f30487a2b4f700 Mon Sep 17 00:00:00 2001 From: kohya-ss Date: Fri, 18 Oct 2024 23:39:48 +0900 Subject: [PATCH 2/3] add FLUX.1 support --- README.md | 19 +++++++ flux_train_network.py | 123 ++++++++++++++++++++++++++++-------------- 2 files changed, 103 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 7fae50d1a..59f70ebcd 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,25 @@ The command to install PyTorch is as follows: ### Recent Updates +Oct 19, 2024: + +- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. + - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words. + - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`. + - See [dataset configuration](docs/config_README-en.md) for the regularization dataset. + - Specify "number of training images x number of epochs >= number of regularization images x number of epochs". + - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000. + - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP. +``` +[[datasets.subsets]] +image_dir = "path/to/image/dir" +num_repeats = 1 +is_reg = true +custom_attributes.diff_output_preservation = true # Add this +``` + + + Oct 13, 2024: - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large. 
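
The pieces above wire a free-form, per-subset setting through to the loss target: `config_util.py` accepts a `custom_attributes` table for each subset, `train_util.py` carries it into every batch as `batch["custom_attributes"]`, and `get_noise_pred_and_target` swaps the target for the flagged samples. A condensed, standalone sketch of that final step (not the patch itself: `call_model` is a hypothetical stand-in for `call_unet`/`call_dit`, while `network.set_multiplier` is the real LoRA-network API used in the diffs):

```python
import torch

def apply_dop_target(network, call_model, noisy_latents, timesteps, conds, target, batch):
    # indices of samples whose subset set custom_attributes.diff_output_preservation = true
    indices = [
        i
        for i, attrs in enumerate(batch.get("custom_attributes", []))
        if attrs.get("diff_output_preservation")
    ]
    if not indices:
        return target

    network.set_multiplier(0.0)  # multiplier 0: the forward pass sees only the base model
    with torch.no_grad():
        prior_pred = call_model(noisy_latents[indices], timesteps[indices], conds[indices])
    network.set_multiplier(1.0)  # restore the LoRA before the gradient step

    target = target.clone()
    target[indices] = prior_pred.to(target.dtype)  # flagged samples regress to the prior output
    return target
```

For regularization captions without the trigger word, the LoRA is thus trained to reproduce what the base model would have produced instead of the usual noise target, which keeps off-trigger behavior close to the original model.
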
diff --git a/flux_train_network.py b/flux_train_network.py index aa92fe3ae..8431a6dc9 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -373,33 +373,13 @@ def get_noise_pred_and_target( if not args.apply_t5_attn_mask: t5_attn_mask = None - if not args.split_mode: - # normal forward - with accelerator.autocast(): - # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) - model_pred = unet( - img=packed_noisy_model_input, - img_ids=img_ids, - txt=t5_out, - txt_ids=txt_ids, - y=l_pooled, - timesteps=timesteps / 1000, - guidance=guidance_vec, - txt_attention_mask=t5_attn_mask, - ) - else: - # split forward to reduce memory usage - assert network.train_blocks == "single", "train_blocks must be single for split mode" - with accelerator.autocast(): - # move flux lower to cpu, and then move flux upper to gpu - unet.to("cpu") - clean_memory_on_device(accelerator.device) - self.flux_upper.to(accelerator.device) - - # upper model does not require grad - with torch.no_grad(): - intermediate_img, intermediate_txt, vec, pe = self.flux_upper( - img=packed_noisy_model_input, + def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): + if not args.split_mode: + # normal forward + with accelerator.autocast(): + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + model_pred = unet( + img=img, img_ids=img_ids, txt=t5_out, txt_ids=txt_ids, @@ -408,18 +388,52 @@ def get_noise_pred_and_target( guidance=guidance_vec, txt_attention_mask=t5_attn_mask, ) - - # move flux upper back to cpu, and then move flux lower to gpu - self.flux_upper.to("cpu") - clean_memory_on_device(accelerator.device) - unet.to(accelerator.device) - - # lower model requires grad - intermediate_img.requires_grad_(True) - intermediate_txt.requires_grad_(True) - vec.requires_grad_(True) - pe.requires_grad_(True) - model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + else: + # split forward to reduce memory usage + assert network.train_blocks == "single", "train_blocks must be single for split mode" + with accelerator.autocast(): + # move flux lower to cpu, and then move flux upper to gpu + unet.to("cpu") + clean_memory_on_device(accelerator.device) + self.flux_upper.to(accelerator.device) + + # upper model does not require grad + with torch.no_grad(): + intermediate_img, intermediate_txt, vec, pe = self.flux_upper( + img=packed_noisy_model_input, + img_ids=img_ids, + txt=t5_out, + txt_ids=txt_ids, + y=l_pooled, + timesteps=timesteps / 1000, + guidance=guidance_vec, + txt_attention_mask=t5_attn_mask, + ) + + # move flux upper back to cpu, and then move flux lower to gpu + self.flux_upper.to("cpu") + clean_memory_on_device(accelerator.device) + unet.to(accelerator.device) + + # lower model requires grad + intermediate_img.requires_grad_(True) + intermediate_txt.requires_grad_(True) + vec.requires_grad_(True) + pe.requires_grad_(True) + model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask) + + return model_pred + + model_pred = call_dit( + img=packed_noisy_model_input, + img_ids=img_ids, + t5_out=t5_out, + txt_ids=txt_ids, + l_pooled=l_pooled, + timesteps=timesteps, + guidance_vec=guidance_vec, + t5_attn_mask=t5_attn_mask, + ) # unpack 
latents
         model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
@@ -430,6 +444,37 @@
         # flow matching loss: this is different from SD3
         target = noise - latents
 
+        # differential output preservation
+        if "custom_attributes" in batch:
+            diff_output_pr_indices = []
+            for i, custom_attributes in enumerate(batch["custom_attributes"]):
+                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
+                    diff_output_pr_indices.append(i)
+
+            if len(diff_output_pr_indices) > 0:
+                network.set_multiplier(0.0)
+                with torch.no_grad(), accelerator.autocast():
+                    model_pred_prior = call_dit(
+                        img=packed_noisy_model_input[diff_output_pr_indices],
+                        img_ids=img_ids[diff_output_pr_indices],
+                        t5_out=t5_out[diff_output_pr_indices],
+                        txt_ids=txt_ids[diff_output_pr_indices],
+                        l_pooled=l_pooled[diff_output_pr_indices],
+                        timesteps=timesteps[diff_output_pr_indices],
+                        guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
+                        t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
+                    )
+                network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step
+
+                model_pred_prior = flux_utils.unpack_latents(model_pred_prior, packed_latent_height, packed_latent_width)
+                model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
+                    args,
+                    model_pred_prior,
+                    noisy_model_input[diff_output_pr_indices],
+                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
+                )
+                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+
         return model_pred, target, timesteps, None, weighting
 
     def post_process_loss(self, loss, args, timesteps, noise_scheduler):

From 2c45d979e696fd4412ae1336feaee3bc9b967af4 Mon Sep 17 00:00:00 2001
From: kohya-ss
Date: Sat, 19 Oct 2024 19:21:12 +0900
Subject: [PATCH 3/3] update README, remove unnecessary autocast

---
 README.md             | 10 ++++------
 flux_train_network.py |  2 +-
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 59f70ebcd..32ee38573 100644
--- a/README.md
+++ b/README.md
@@ -13,13 +13,13 @@ The command to install PyTorch is as follows:
 
 Oct 19, 2024:
 
-- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training.
+- Added an implementation of Differential Output Preservation (temporary name) for SDXL/FLUX.1 LoRA training. SD1/2 is not tested yet. This is an experimental feature.
   - A method to make the output of LoRA closer to the output when LoRA is not applied, with captions that do not contain trigger words.
   - Define a Dataset subset for the regularization image (`is_reg = true`) with `.toml`. Add `custom_attributes.diff_output_preservation = true`.
   - See [dataset configuration](docs/config_README-en.md) for the regularization dataset.
-  - Specify "number of training images x number of epochs >= number of regularization images x number of epochs".
-  - Specify a large value for `--prior_loss_weight` option (not dataset config). We recommend 10-1000.
-  - Set the loss in the training without using the regularization image to be close to the loss in the training using DOP.
+  - Specify "number of training images x number of repeats >= number of regularization images x number of repeats".
+  - Specify a large value for `--prior_loss_weight` option (not dataset config). The appropriate value is unknown, but try around 10-100. Note that the default is 1.0.
+  - You may want to start with 2/3 to 3/4 of the loss value when DOP is not applied. If it is 1/2, DOP may not be working.
 ```
 [[datasets.subsets]]
 image_dir = "path/to/image/dir"
 num_repeats = 1
 is_reg = true
 custom_attributes.diff_output_preservation = true # Add this
 ```
-
-
 Oct 13, 2024:
 
 - Fixed an issue where it took a long time to load the image size when initializing the dataset, especially when the number of images in the dataset was large.
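
A note on how `--prior_loss_weight` reaches these samples: the DOP subset is declared with `is_reg = true`, so the dataset's `__getitem__` (patch 1) already records `prior_loss_weight` for it in `batch["loss_weights"]`, and the training loop scales each sample's loss by that value. A condensed sketch of the weighting, assuming per-sample reduction before weighting as in `train_network.py`:

```python
import torch
import torch.nn.functional as F

def dop_weighted_loss(model_pred, target, is_reg, prior_loss_weight=10.0):
    # per-sample MSE, averaged over every non-batch dimension
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.dim())))
    # reg/DOP samples carry --prior_loss_weight, ordinary samples carry 1.0
    weights = torch.where(is_reg, torch.full_like(loss, prior_loss_weight), torch.ones_like(loss))
    return (loss * weights).mean()
```

This is presumably why a large weight is suggested: early in training the LoRA output barely differs from the base model, so the DOP residual is small and needs scaling up to register against the ordinary term.
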
diff --git a/flux_train_network.py b/flux_train_network.py
index 8431a6dc9..9cc8811b5 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -453,7 +453,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
 
             if len(diff_output_pr_indices) > 0:
                 network.set_multiplier(0.0)
-                with torch.no_grad(), accelerator.autocast():
+                with torch.no_grad():
                     model_pred_prior = call_dit(
                         img=packed_noisy_model_input[diff_output_pr_indices],
                         img_ids=img_ids[diff_output_pr_indices],
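
Two side notes on the FLUX.1 path. First, the `call_dit` helper introduced in patch 2 exists so that the `--split_mode` two-stage forward can be reused verbatim for the extra DOP pass. Schematically, split mode runs the frozen upper half of the model without gradients and re-attaches autograd at the split point before the trainable lower half; a simplified sketch, not the actual module interfaces:

```python
import torch

def split_forward(upper, lower, device, img, cond):
    # stage 1: upper half on the accelerator, lower half parked on CPU, no grad needed
    lower.to("cpu")
    upper.to(device)
    with torch.no_grad():
        hidden = upper(img, cond)

    # stage 2: swap the halves; gradients begin at the intermediate activations
    upper.to("cpu")
    lower.to(device)
    hidden.requires_grad_(True)  # re-attach autograd at the split point
    return lower(hidden, cond)
```

Second, patch 3 drops the outer `accelerator.autocast()` around the prior prediction because `call_dit` already enters `accelerator.autocast()` internally; the outer context manager was redundant.
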
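
Finally, a hypothetical end-to-end dataset config combining a trigger-word subset with a DOP regularization subset. Paths and counts are placeholders, with repeats chosen so that "number of training images x number of repeats >= number of regularization images x number of repeats" holds as the README advises:

```toml
[[datasets]]
resolution = 1024
batch_size = 2

  [[datasets.subsets]]
  image_dir = "path/to/train/images"  # captions contain the trigger word
  num_repeats = 4

  [[datasets.subsets]]
  image_dir = "path/to/reg/images"    # captions written without the trigger word
  num_repeats = 1
  is_reg = true
  custom_attributes.diff_output_preservation = true
```

`--prior_loss_weight` itself goes on the training command line rather than in this file, as noted above.
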