diff --git a/library/config_util.py b/library/config_util.py
index f8cdfe60a..fc1fbf46d 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -10,13 +10,7 @@
 from pathlib import Path
 
 # from toolz import curry
-from typing import (
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Union,
-)
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import toml
 import voluptuous
@@ -78,6 +72,7 @@ class BaseSubsetParams:
     caption_tag_dropout_rate: float = 0.0
     token_warmup_min: int = 1
     token_warmup_step: float = 0
+    custom_attributes: Optional[Dict[str, Any]] = None
 
 
 @dataclass
@@ -197,6 +192,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
     "token_warmup_step": Any(float, int),
     "caption_prefix": str,
     "caption_suffix": str,
+    "custom_attributes": dict,
 }
 # DO means DropOut
 DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -538,9 +534,10 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
                 flip_aug: {subset.flip_aug}
                 face_crop_aug_range: {subset.face_crop_aug_range}
                 random_crop: {subset.random_crop}
-                token_warmup_min: {subset.token_warmup_min},
-                token_warmup_step: {subset.token_warmup_step},
-                alpha_mask: {subset.alpha_mask},
+                token_warmup_min: {subset.token_warmup_min}
+                token_warmup_step: {subset.token_warmup_step}
+                alpha_mask: {subset.alpha_mask}
+                custom_attributes: {subset.custom_attributes}
             """
             ),
             "    ",
diff --git a/library/train_util.py b/library/train_util.py
index 4a446e81c..7d3fce5b2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -396,6 +396,7 @@ def __init__(
         caption_suffix: Optional[str],
         token_warmup_min: int,
         token_warmup_step: Union[float, int],
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.image_dir = image_dir
         self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -419,6 +420,8 @@
         self.token_warmup_min = token_warmup_min  # number of tags at step=0
         self.token_warmup_step = token_warmup_step  # tag count reaches its maximum at step N (N*max_train_steps if N<1)
 
+        self.custom_attributes = custom_attributes if custom_attributes is not None else {}
+
         self.img_count = 0
@@ -449,6 +452,7 @@ def __init__(
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
 
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -473,6 +477,7 @@ def __init__(
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
+            custom_attributes=custom_attributes,
         )
 
         self.is_reg = is_reg
@@ -512,6 +517,7 @@ def __init__(
         caption_suffix,
         token_warmup_min,
         token_warmup_step,
+        custom_attributes: Optional[Dict[str, Any]] = None,
     ) -> None:
 
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -536,6 +542,7 @@ def __init__(
             caption_suffix,
             token_warmup_min,
             token_warmup_step,
+            custom_attributes=custom_attributes,
         )
 
         self.metadata_file = metadata_file
@@ -1474,11 +1481,14 @@ def __getitem__(self, index):
         target_sizes_hw = []
         flippeds = []  # the variable name could be better
         text_encoder_outputs_list = []
+        custom_attributes = []
 
         for image_key in bucket[image_index : image_index + bucket_batch_size]:
             image_info = self.image_data[image_key]
             subset = self.image_to_subset[image_key]
 
+            custom_attributes.append(subset.custom_attributes)
+
             # in case of fine tuning, is_reg is always False
             loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
 
@@ -1646,7 +1656,9 @@ def none_or_stack_elements(tensors_list, converter):
                 return None
             return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
 
+        # set example
         example = {}
+        example["custom_attributes"] = custom_attributes  # may be a list of empty dicts
         example["loss_weights"] = torch.FloatTensor(loss_weights)
         example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
         example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
@@ -2630,7 +2642,9 @@ def debug_dataset(train_dataset, show_input_ids=False):
                 f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}", original size: {orgsz}, crop top left: {crptl}, target size: {trgsz}, flipped: {flpdz}'
             )
             if "network_multipliers" in example:
-                print(f"network multiplier: {example['network_multipliers'][j]}")
+                logger.info(f"network multiplier: {example['network_multipliers'][j]}")
+            if "custom_attributes" in example:
+                logger.info(f"custom attributes: {example['custom_attributes'][j]}")
 
             # if show_input_ids:
             #     logger.info(f"input ids: {iid}")
@@ -4091,6 +4105,7 @@ def enable_high_vram(args: argparse.Namespace):
     global HIGH_VRAM
     HIGH_VRAM = True
 
+
 def verify_training_args(args: argparse.Namespace):
     r"""
     Verify training arguments. Also reflect highvram option to global variable
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 4a16a4891..d45df6e05 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -1,4 +1,5 @@
 import argparse
+from typing import List, Optional
 
 import torch
 from accelerate import Accelerator
@@ -172,7 +173,18 @@ def get_text_cond(self, args, accelerator, batch, tokenizers, text_encoders, wei
 
         return encoder_hidden_states1, encoder_hidden_states2, pool2
 
-    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+    def call_unet(
+        self,
+        args,
+        accelerator,
+        unet,
+        noisy_latents,
+        timesteps,
+        text_conds,
+        batch,
+        weight_dtype,
+        indices: Optional[List[int]] = None,
+    ):
         noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
         # get size embeddings
@@ -186,6 +198,12 @@ def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_cond
         vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
         text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)
 
+        if indices is not None and len(indices) > 0:
+            noisy_latents = noisy_latents[indices]
+            timesteps = timesteps[indices]
+            text_embedding = text_embedding[indices]
+            vector_embedding = vector_embedding[indices]
+
         noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
         return noise_pred
 
diff --git a/train_network.py b/train_network.py
index d5330aef4..ef766737d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -143,7 +143,7 @@ def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, tex
         for t_enc in text_encoders:
             t_enc.to(accelerator.device, dtype=weight_dtype)
 
-    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
+    def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
         noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
         return noise_pred
 
@@ -218,6 +218,30 @@ def get_noise_pred_and_target(
         else:
             target = noise
 
+        # differential output preservation
+        if "custom_attributes" in batch:
+            diff_output_pr_indices = []
+            for i, custom_attributes in enumerate(batch["custom_attributes"]):
+                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
custom_attributes["diff_output_preservation"]: + diff_output_pr_indices.append(i) + + if len(diff_output_pr_indices) > 0: + network.set_multiplier(0.0) + with torch.no_grad(), accelerator.autocast(): + noise_pred_prior = self.call_unet( + args, + accelerator, + unet, + noisy_latents, + timesteps, + text_encoder_conds, + batch, + weight_dtype, + indices=diff_output_pr_indices, + ) + network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step + target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) + return noise_pred, target, timesteps, huber_c, None def post_process_loss(self, loss, args, timesteps, noise_scheduler): @@ -1123,15 +1147,6 @@ def remove_model(old_ckpt_name): with torch.set_grad_enabled(train_text_encoder), accelerator.autocast(): # Get the text embedding for conditioning if args.weighted_captions: - # # SD only - # encoded_text_encoder_conds = get_weighted_text_embeddings( - # tokenizers[0], - # text_encoder, - # batch["captions"], - # accelerator.device, - # args.max_token_length // 75 if args.max_token_length else 1, - # clip_skip=args.clip_skip, - # ) input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"]) encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights( tokenize_strategy,