diff --git a/.gitignore b/.gitignore index 506bc9f0..c5b1bb4f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ /venv* /debug* /workspace* +/models* /training_concepts /training_samples debug.py diff --git a/modules/dataLoader/MgdsKandinskyBaseDataLoader.py b/modules/dataLoader/MgdsKandinskyBaseDataLoader.py index 1b7c0fd0..825971c0 100644 --- a/modules/dataLoader/MgdsKandinskyBaseDataLoader.py +++ b/modules/dataLoader/MgdsKandinskyBaseDataLoader.py @@ -62,7 +62,7 @@ def _load_input_modules(self, args: TrainArgs, model: KandinskyModel) -> list: select_prompt_input = SelectInput(setting_name='concept.prompt_source', out_name='prompts', setting_to_in_name_map={ 'sample': 'sample_prompts', 'concept': 'concept_prompts', - 'filename': 'filename_prompts', + 'filename': 'filename_prompt', }, default_in_name='sample_prompts') select_random_text = SelectRandomText(texts_in_name='prompts', text_out_name='prompt') diff --git a/modules/dataLoader/MgdsStableDiffusionFineTuneDataLoader.py b/modules/dataLoader/MgdsStableDiffusionFineTuneDataLoader.py index f173e122..e4c0f31d 100644 --- a/modules/dataLoader/MgdsStableDiffusionFineTuneDataLoader.py +++ b/modules/dataLoader/MgdsStableDiffusionFineTuneDataLoader.py @@ -1,5 +1,3 @@ -from mgds.GenericDataLoaderModules import * - from modules.dataLoader.MgdsStableDiffusionBaseDataLoader import MgdsStablDiffusionBaseDataLoader from modules.model.StableDiffusionModel import StableDiffusionModel from modules.util.TrainProgress import TrainProgress diff --git a/modules/dataLoader/MgdsStableDiffusionFineXLTuneDataLoader.py b/modules/dataLoader/MgdsStableDiffusionFineXLTuneDataLoader.py new file mode 100644 index 00000000..2fd68db4 --- /dev/null +++ b/modules/dataLoader/MgdsStableDiffusionFineXLTuneDataLoader.py @@ -0,0 +1,14 @@ +from modules.dataLoader.MgdsStableDiffusionXLBaseDataLoader import MgdsStablDiffusionXLBaseDataLoader +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.util.TrainProgress import TrainProgress +from modules.util.args.TrainArgs import TrainArgs + + +class MgdsStableDiffusionXLFineTuneDataLoader(MgdsStablDiffusionXLBaseDataLoader): + def __init__( + self, + args: TrainArgs, + model: StableDiffusionXLModel, + train_progress: TrainProgress, + ): + super(MgdsStableDiffusionXLFineTuneDataLoader, self).__init__(args, model, train_progress) diff --git a/modules/dataLoader/MgdsStableDiffusionXLBaseDataLoader.py b/modules/dataLoader/MgdsStableDiffusionXLBaseDataLoader.py new file mode 100644 index 00000000..fce19ebd --- /dev/null +++ b/modules/dataLoader/MgdsStableDiffusionXLBaseDataLoader.py @@ -0,0 +1,376 @@ +import json + +from mgds.DebugDataLoaderModules import DecodeVAE, SaveImage, SaveText, DecodeTokens +from mgds.DiffusersDataLoaderModules import * +from mgds.GenericDataLoaderModules import * +from mgds.MGDS import MGDS, TrainDataLoader, OutputPipelineModule +from mgds.TransformersDataLoaderModules import * + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.util import path_util +from modules.util.TrainProgress import TrainProgress +from modules.util.args.TrainArgs import TrainArgs + + +class MgdsStablDiffusionXLBaseDataLoader: + def __init__( + self, + args: TrainArgs, + model: StableDiffusionXLModel, + train_progress: TrainProgress, + ): + with open(args.concept_file_name, 'r') as f: + concepts = json.load(f) + + self.ds = self.create_dataset( + args=args, + model=model, + concepts=concepts, + train_progress=train_progress, + ) + self.dl = 
TrainDataLoader(self.ds, args.batch_size) + + + def _enumerate_input_modules(self, args: TrainArgs) -> list: + supported_extensions = path_util.supported_image_extensions() + + collect_paths = CollectPaths( + concept_in_name='concept', path_in_name='path', name_in_name='name', path_out_name='image_path', concept_out_name='concept', + extensions=supported_extensions, include_postfix=None, exclude_postfix=['-masklabel'], include_subdirectories_in_name='concept.include_subdirectories' + ) + + mask_path = ModifyPath(in_name='image_path', out_name='mask_path', postfix='-masklabel', extension='.png') + sample_prompt_path = ModifyPath(in_name='image_path', out_name='sample_prompt_path', postfix='', extension='.txt') + + modules = [collect_paths, sample_prompt_path] + + if args.masked_training: + modules.append(mask_path) + + return modules + + + def _load_input_modules(self, args: TrainArgs, model: StableDiffusionXLModel) -> list: + load_image = LoadImage(path_in_name='image_path', image_out_name='image', range_min=0, range_max=1) + + generate_mask = GenerateImageLike(image_in_name='image', image_out_name='mask', color=255, range_min=0, range_max=1, channels=1) + load_mask = LoadImage(path_in_name='mask_path', image_out_name='mask', range_min=0, range_max=1, channels=1) + + load_sample_prompts = LoadMultipleTexts(path_in_name='sample_prompt_path', texts_out_name='sample_prompts') + load_concept_prompts = LoadMultipleTexts(path_in_name='concept.prompt_path', texts_out_name='concept_prompts') + filename_prompt = GetFilename(path_in_name='image_path', filename_out_name='filename_prompt', include_extension=False) + select_prompt_input = SelectInput(setting_name='concept.prompt_source', out_name='prompts', setting_to_in_name_map={ + 'sample': 'sample_prompts', + 'concept': 'concept_prompts', + 'filename': 'filename_prompt', + }, default_in_name='sample_prompts') + select_random_text = SelectRandomText(texts_in_name='prompts', text_out_name='prompt') + + modules = [load_image, load_sample_prompts, load_concept_prompts, filename_prompt, select_prompt_input, select_random_text] + + if args.masked_training: + modules.append(generate_mask) + modules.append(load_mask) + elif args.model_type.has_mask_input(): + modules.append(generate_mask) + + return modules + + + def _mask_augmentation_modules(self, args: TrainArgs) -> list: + inputs = ['image'] + + if args.model_type.has_depth_input(): + inputs.append('depth') + + circular_mask_shrink = RandomCircularMaskShrink(mask_name='mask', shrink_probability=1.0, shrink_factor_min=0.2, shrink_factor_max=1.0, enabled_in_name='settings.enable_random_circular_mask_shrink') + random_mask_rotate_crop = RandomMaskRotateCrop(mask_name='mask', additional_names=inputs, min_size=args.resolution, min_padding_percent=10, max_padding_percent=30, max_rotate_angle=20, enabled_in_name='settings.enable_random_mask_rotate_crop') + + modules = [] + + if args.masked_training or args.model_type.has_mask_input(): + modules.append(circular_mask_shrink) + + if args.masked_training or args.model_type.has_mask_input(): + modules.append(random_mask_rotate_crop) + + return modules + + + def _aspect_bucketing_in(self, args: TrainArgs): + calc_aspect = CalcAspect(image_in_name='image', resolution_out_name='original_resolution') + + aspect_bucketing = AspectBucketing( + target_resolution=args.resolution, + quantization=64, + resolution_in_name='original_resolution', + scale_resolution_out_name='scale_resolution', + crop_resolution_out_name='crop_resolution', + 
possible_resolutions_out_name='possible_resolutions' + ) + + single_aspect_calculation = SingleAspectCalculation( + target_resolution=args.resolution, + resolution_in_name='original_resolution', + scale_resolution_out_name='scale_resolution', + crop_resolution_out_name='crop_resolution', + possible_resolutions_out_name='possible_resolutions' + ) + + modules = [calc_aspect] + + if args.aspect_ratio_bucketing: + modules.append(aspect_bucketing) + else: + modules.append(single_aspect_calculation) + + return modules + + + def _crop_modules(self, args: TrainArgs): + scale_crop_image = ScaleCropImage(image_in_name='image', scale_resolution_in_name='scale_resolution', crop_resolution_in_name='crop_resolution', image_out_name='image', enable_crop_jitter_in_name='concept.enable_crop_jitter') + scale_crop_mask = ScaleCropImage(image_in_name='mask', scale_resolution_in_name='scale_resolution', crop_resolution_in_name='crop_resolution', image_out_name='mask', enable_crop_jitter_in_name='concept.enable_crop_jitter') + scale_crop_depth = ScaleCropImage(image_in_name='depth', scale_resolution_in_name='scale_resolution', crop_resolution_in_name='crop_resolution', image_out_name='depth', enable_crop_jitter_in_name='concept.enable_crop_jitter') + + modules = [scale_crop_image] + + if args.masked_training or args.model_type.has_mask_input(): + modules.append(scale_crop_mask) + + if args.model_type.has_depth_input(): + modules.append(scale_crop_depth) + + return modules + + + def _augmentation_modules(self, args: TrainArgs): + inputs = ['image'] + + if args.masked_training or args.model_type.has_mask_input(): + inputs.append('mask') + + if args.model_type.has_depth_input(): + inputs.append('depth') + + random_flip = RandomFlip(names=inputs, enabled_in_name='concept.enable_random_flip') + random_rotate = RandomRotate(names=inputs, enabled_in_name='concept.enable_random_rotate', max_angle_in_name='concept.random_rotate_max_angle') + random_brightness = RandomBrightness(names=['image'], enabled_in_name='concept.enable_random_brightness', max_strength_in_name='concept.random_brightness_max_strength') + random_contrast = RandomContrast(names=['image'], enabled_in_name='concept.enable_random_contrast', max_strength_in_name='concept.random_contrast_max_strength') + random_saturation = RandomSaturation(names=['image'], enabled_in_name='concept.enable_random_saturation', max_strength_in_name='concept.random_saturation_max_strength') + random_hue = RandomHue(names=['image'], enabled_in_name='concept.enable_random_hue', max_strength_in_name='concept.random_hue_max_strength') + + modules = [ + random_flip, + random_rotate, + random_brightness, + random_contrast, + random_saturation, + random_hue, + ] + + return modules + + + def _inpainting_modules(self, args: TrainArgs): + conditioning_image = GenerateMaskedConditioningImage(image_in_name='image', mask_in_name='mask', image_out_name='conditioning_image', image_range_min=0, image_range_max=1) + + modules = [] + + if args.model_type.has_conditioning_image_input(): + modules.append(conditioning_image) + + return modules + + + def _preparation_modules(self, args: TrainArgs, model: StableDiffusionXLModel): + rescale_image = RescaleImageChannels(image_in_name='image', image_out_name='image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) + rescale_conditioning_image = RescaleImageChannels(image_in_name='conditioning_image', image_out_name='conditioning_image', in_range_min=0, in_range_max=1, out_range_min=-1, out_range_max=1) + encode_image = 
EncodeVAE(in_name='image', out_name='latent_image_distribution', vae=model.vae, override_allow_mixed_precision=False) + downscale_mask = Downscale(in_name='mask', out_name='latent_mask', factor=8) + encode_conditioning_image = EncodeVAE(in_name='conditioning_image', out_name='latent_conditioning_image_distribution', vae=model.vae, override_allow_mixed_precision=False) + downscale_depth = Downscale(in_name='depth', out_name='latent_depth', factor=8) + tokenize_prompt_1 = Tokenize(in_name='prompt', tokens_out_name='tokens_1', mask_out_name='tokens_mask_1', tokenizer=model.tokenizer_1, max_token_length=model.tokenizer_1.model_max_length) + tokenize_prompt_2 = Tokenize(in_name='prompt', tokens_out_name='tokens_2', mask_out_name='tokens_mask_2', tokenizer=model.tokenizer_2, max_token_length=model.tokenizer_2.model_max_length) + encode_prompt_1 = EncodeClipText(in_name='tokens_1', hidden_states_out_name='text_encoder_1_hidden_state', pooled_out_name=None, text_encoder=model.text_encoder_1) + encode_prompt_2 = EncodeClipText(in_name='tokens_2', hidden_states_out_name='text_encoder_2_hidden_state', pooled_out_name='text_encoder_2_pooled_state', text_encoder=model.text_encoder_2) + + modules = [ + rescale_image, encode_image, + tokenize_prompt_1, encode_prompt_1, + tokenize_prompt_2, encode_prompt_2, + ] + + if args.masked_training or args.model_type.has_mask_input(): + modules.append(downscale_mask) + + if args.model_type.has_conditioning_image_input(): + modules.append(rescale_conditioning_image) + modules.append(encode_conditioning_image) + + if args.model_type.has_depth_input(): + modules.append(downscale_depth) + + return modules + + + def _cache_modules(self, args: TrainArgs): + split_names = [ + 'latent_image_distribution', + 'tokens_1', 'text_encoder_1_hidden_state', + 'tokens_2', 'text_encoder_2_hidden_state', 'text_encoder_2_pooled_state', + ] + + if args.masked_training or args.model_type.has_mask_input(): + split_names.append('latent_mask') + + if args.model_type.has_conditioning_image_input(): + split_names.append('latent_conditioning_image_distribution') + + if args.model_type.has_depth_input(): + split_names.append('latent_depth') + + aggregate_names = ['crop_resolution', 'image_path'] + + disk_cache = DiskCache(cache_dir=args.cache_dir, split_names=split_names, aggregate_names=aggregate_names, cached_epochs=args.latent_caching_epochs) + ram_cache = RamCache(names=split_names + aggregate_names) + + modules = [] + + if args.latent_caching: + modules.append(disk_cache) + else: + modules.append(ram_cache) + + return modules + + + def _output_modules(self, args: TrainArgs, model: StableDiffusionXLModel): + output_names = [ + 'image_path', 'latent_image', + 'tokens_1', 'text_encoder_1_hidden_state', + 'tokens_2', 'text_encoder_2_hidden_state', 'text_encoder_2_pooled_state', + ] + + if args.masked_training or args.model_type.has_mask_input(): + output_names.append('latent_mask') + + if args.model_type.has_conditioning_image_input(): + output_names.append('latent_conditioning_image') + + if args.model_type.has_depth_input(): + output_names.append('latent_depth') + + image_sample = SampleVAEDistribution(in_name='latent_image_distribution', out_name='latent_image', mode='mean') + conditioning_image_sample = SampleVAEDistribution(in_name='latent_conditioning_image_distribution', out_name='latent_conditioning_image', mode='mean') + mask_remove = RandomLatentMaskRemove( + latent_mask_name='latent_mask', latent_conditioning_image_name='latent_conditioning_image', + 
replace_probability=args.unmasked_probability, vae=model.vae, possible_resolutions_in_name='possible_resolutions' + ) + batch_sorting = AspectBatchSorting(resolution_in_name='crop_resolution', names=output_names, batch_size=args.batch_size, sort_resolutions_for_each_epoch=True) + output = OutputPipelineModule(names=output_names) + + modules = [image_sample] + + if args.model_type.has_conditioning_image_input(): + modules.append(conditioning_image_sample) + + if args.model_type.has_mask_input(): + modules.append(mask_remove) + + if args.aspect_ratio_bucketing: + modules.append(batch_sorting) + + modules.append(output) + + return modules + + + def _debug_modules(self, args: TrainArgs, model: StableDiffusionXLModel): + debug_dir = os.path.join(args.debug_dir, "dataloader") + + decode_image = DecodeVAE(in_name='latent_image', out_name='decoded_image', vae=model.vae, override_allow_mixed_precision=False) + decode_conditioning_image = DecodeVAE(in_name='latent_conditioning_image', out_name='decoded_conditioning_image', vae=model.vae, override_allow_mixed_precision=False) + upscale_mask = Upscale(in_name='latent_mask', out_name='decoded_mask', factor=8) + decode_prompt = DecodeTokens(in_name='tokens_1', out_name='decoded_prompt', tokenizer=model.tokenizer_1) + save_image = SaveImage(image_in_name='decoded_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1) + save_conditioning_image = SaveImage(image_in_name='decoded_conditioning_image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1) + # SaveImage(image_in_name='latent_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), + save_mask = SaveImage(image_in_name='decoded_mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1) + # SaveImage(image_in_name='latent_depth', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + save_prompt = SaveText(text_in_name='decoded_prompt', original_path_in_name='image_path', path=debug_dir) + + # These modules don't really work, since they are inserted after a sorting operation that does not include this data + # SaveImage(image_in_name='mask', original_path_in_name='image_path', path=debug_dir, in_range_min=0, in_range_max=1), + # SaveImage(image_in_name='depth', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + # SaveImage(image_in_name='image', original_path_in_name='image_path', path=debug_dir, in_range_min=-1, in_range_max=1), + + modules = [] + + modules.append(decode_image) + modules.append(save_image) + + if args.model_type.has_conditioning_image_input(): + modules.append(decode_conditioning_image) + modules.append(save_conditioning_image) + + if args.masked_training or args.model_type.has_mask_input(): + modules.append(upscale_mask) + modules.append(save_mask) + + modules.append(decode_prompt) + modules.append(save_prompt) + + return modules + + + def create_dataset( + self, + args: TrainArgs, + model: StableDiffusionXLModel, + concepts: list[dict], + train_progress: TrainProgress, + ): + enumerate_input = self._enumerate_input_modules(args) + load_input = self._load_input_modules(args, model) + mask_augmentation = self._mask_augmentation_modules(args) + aspect_bucketing_in = self._aspect_bucketing_in(args) + crop_modules = self._crop_modules(args) + augmentation_modules = self._augmentation_modules(args) + inpainting_modules = self._inpainting_modules(args) + preparation_modules = 
self._preparation_modules(args, model) + cache_modules = self._cache_modules(args) + output_modules = self._output_modules(args, model) + + debug_modules = self._debug_modules(args, model) + + settings = { + "enable_random_circular_mask_shrink": args.circular_mask_generation, + "enable_random_mask_rotate_crop": args.random_rotate_and_crop, + } + + ds = MGDS( + torch.device(args.train_device), + args.train_dtype.torch_dtype(), + args.train_dtype.enable_mixed_precision(), + concepts, + settings, + [ + enumerate_input, + load_input, + mask_augmentation, + aspect_bucketing_in, + crop_modules, + augmentation_modules, + inpainting_modules, + preparation_modules, + cache_modules, + output_modules, + + debug_modules if args.debug_mode else None, # inserted before output_modules, which contains a sorting operation + ], + batch_size=args.batch_size, + initial_epoch=train_progress.epoch, + initial_epoch_sample=train_progress.epoch_sample, + ) + + return ds diff --git a/modules/model/StableDiffusionXLModel.py b/modules/model/StableDiffusionXLModel.py new file mode 100644 index 00000000..d1fd83e2 --- /dev/null +++ b/modules/model/StableDiffusionXLModel.py @@ -0,0 +1,67 @@ +from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, EulerDiscreteScheduler, \ + StableDiffusionXLPipeline +from transformers import CLIPTextModel, CLIPTokenizer + +from modules.model.BaseModel import BaseModel +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util.TrainProgress import TrainProgress +from modules.util.enum.ModelType import ModelType + + +class StableDiffusionXLModel(BaseModel): + # base model data + model_type: ModelType + tokenizer_1: CLIPTokenizer + tokenizer_2: CLIPTokenizer + noise_scheduler: EulerDiscreteScheduler + text_encoder_1: CLIPTextModel + text_encoder_2: CLIPTextModel + vae: AutoencoderKL + unet: UNet2DConditionModel + + # persistent training data + text_encoder_1_lora: LoRAModuleWrapper | None + text_encoder_2_lora: LoRAModuleWrapper | None + unet_lora: LoRAModuleWrapper | None + + def __init__( + self, + model_type: ModelType, + tokenizer_1: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + noise_scheduler: EulerDiscreteScheduler, + text_encoder_1: CLIPTextModel, + text_encoder_2: CLIPTextModel, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + optimizer_state_dict: dict | None = None, + ema_state_dict: dict | None = None, + train_progress: TrainProgress = None, + text_encoder_1_lora: LoRAModuleWrapper | None = None, + text_encoder_2_lora: LoRAModuleWrapper | None = None, + unet_lora: LoRAModuleWrapper | None = None, + ): + super(StableDiffusionXLModel, self).__init__(model_type, optimizer_state_dict, ema_state_dict, train_progress) + + self.tokenizer_1 = tokenizer_1 + self.tokenizer_2 = tokenizer_2 + self.noise_scheduler = noise_scheduler + self.text_encoder_1 = text_encoder_1 + self.text_encoder_2 = text_encoder_2 + self.vae = vae + self.unet = unet + + self.text_encoder_1_lora = text_encoder_1_lora + self.text_encoder_2_lora = text_encoder_2_lora + self.unet_lora = unet_lora + + def create_pipeline(self) -> DiffusionPipeline: + return StableDiffusionXLPipeline( + vae=self.vae, + text_encoder=self.text_encoder_1, + text_encoder_2=self.text_encoder_2, + tokenizer=self.tokenizer_1, + tokenizer_2=self.tokenizer_2, + unet=self.unet, + scheduler=self.noise_scheduler, + ) diff --git a/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py new file mode 100644 index 00000000..eb3ec9c0 --- /dev/null +++ 
b/modules/modelLoader/StableDiffusionXLLoRAModelLoader.py @@ -0,0 +1,121 @@ +import json +import os + +import torch +from safetensors.torch import load_file +from torch import Tensor + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelLoader.BaseModelLoader import BaseModelLoader +from modules.modelLoader.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util.TrainProgress import TrainProgress +from modules.util.enum.ModelType import ModelType + + +class StableDiffusionXLLoRAModelLoader(BaseModelLoader): + def __init__(self): + super(StableDiffusionXLLoRAModelLoader, self).__init__() + + @staticmethod + def __get_rank(state_dict: dict) -> int: + for name, state in state_dict.items(): + if "lora_down.weight" in name: + return state.shape[0] + + @staticmethod + def __init_lora(model: StableDiffusionXLModel, state_dict: dict[str, Tensor]): + rank = StableDiffusionXLLoRAModelLoader.__get_rank(state_dict) + + model.text_encoder_1_lora = LoRAModuleWrapper( + orig_module=model.text_encoder_1, + rank=rank, + prefix="lora_te", + ).to(dtype=torch.float32) + model.text_encoder_1_lora.load_state_dict(state_dict) + + model.text_encoder_2_lora = LoRAModuleWrapper( + orig_module=model.text_encoder_2, + rank=rank, + prefix="lora_te_2", + ).to(dtype=torch.float32) + model.text_encoder_2_lora.load_state_dict(state_dict) + + model.unet_lora = LoRAModuleWrapper( + orig_module=model.unet, + rank=rank, + prefix="lora_unet", + module_filter=["attentions"], + ).to(dtype=torch.float32) + model.unet_lora.load_state_dict(state_dict) + + @staticmethod + def __load_safetensors(model: StableDiffusionXLModel, lora_name: str) -> bool: + try: + state_dict = load_file(lora_name) + StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict) + return True + except: + return False + + @staticmethod + def __load_ckpt(model: StableDiffusionXLModel, lora_name: str) -> bool: + try: + state_dict = torch.load(lora_name) + StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict) + return True + except: + return False + + @staticmethod + def __load_internal(model: StableDiffusionXLModel, lora_name: str) -> bool: + try: + with open(os.path.join(lora_name, "meta.json"), "r") as meta_file: + meta = json.load(meta_file) + train_progress = TrainProgress( + epoch=meta['train_progress']['epoch'], + epoch_step=meta['train_progress']['epoch_step'], + epoch_sample=meta['train_progress']['epoch_sample'], + global_step=meta['train_progress']['global_step'], + ) + + # embedding model + loaded = StableDiffusionXLLoRAModelLoader.__load_ckpt( + model, + os.path.join(lora_name, "lora", "lora.pt") + ) + if not loaded: + return False + + # optimizer + try: + model.optimizer_state_dict = torch.load(os.path.join(lora_name, "optimizer", "optimizer.pt")) + except FileNotFoundError: + pass + + # ema + try: + model.ema_state_dict = torch.load(os.path.join(lora_name, "ema", "ema.pt")) + except FileNotFoundError: + pass + + # meta + model.train_progress = train_progress + + return True + except: + return False + + def load(self, model_type: ModelType, base_model_name: str, extra_model_name: str | None) -> StableDiffusionXLModel | None: + base_model_loader = StableDiffusionXLModelLoader() + model = base_model_loader.load(model_type, base_model_name, None) + + lora_loaded = self.__load_internal(model, extra_model_name) + + if not lora_loaded: + lora_loaded = self.__load_ckpt(model, extra_model_name) + + if not lora_loaded: + 
lora_loaded = self.__load_safetensors(model, extra_model_name) + + return model diff --git a/modules/modelLoader/StableDiffusionXLModelLoader.py b/modules/modelLoader/StableDiffusionXLModelLoader.py new file mode 100644 index 00000000..f73fdc8c --- /dev/null +++ b/modules/modelLoader/StableDiffusionXLModelLoader.py @@ -0,0 +1,196 @@ +import json +import os + +import torch +import yaml +from diffusers import AutoencoderKL, UNet2DConditionModel, EulerDiscreteScheduler, DDIMScheduler +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelLoader.BaseModelLoader import BaseModelLoader +from modules.util.TrainProgress import TrainProgress +from modules.util.enum.ModelType import ModelType + + +class StableDiffusionXLModelLoader(BaseModelLoader): + def __init__(self): + super(StableDiffusionXLModelLoader, self).__init__() + + @staticmethod + def __load_internal(model_type: ModelType, base_model_name: str) -> StableDiffusionXLModel | None: + try: + with open(os.path.join(base_model_name, "meta.json"), "r") as meta_file: + meta = json.load(meta_file) + train_progress = TrainProgress( + epoch=meta['train_progress']['epoch'], + epoch_step=meta['train_progress']['epoch_step'], + epoch_sample=meta['train_progress']['epoch_sample'], + global_step=meta['train_progress']['global_step'], + ) + + # base model + model = StableDiffusionXLModelLoader.__load_diffusers(model_type, base_model_name) + + # optimizer + try: + model.optimizer_state_dict = torch.load(os.path.join(base_model_name, "optimizer", "optimizer.pt")) + except FileNotFoundError: + pass + + # ema + try: + model.ema_state_dict = torch.load(os.path.join(base_model_name, "ema", "ema.pt")) + except FileNotFoundError: + pass + + # meta + model.train_progress = train_progress + + return model + except: + return None + + @staticmethod + def __load_diffusers(model_type: ModelType, base_model_name: str) -> StableDiffusionXLModel | None: + try: + tokenizer_1 = CLIPTokenizer.from_pretrained( + base_model_name, + subfolder="tokenizer", + ) + + tokenizer_2 = CLIPTokenizer.from_pretrained( + base_model_name, + subfolder="tokenizer_2", + ) + + noise_scheduler = DDIMScheduler.from_pretrained( + base_model_name, + subfolder="scheduler", + ) + + text_encoder_1 = CLIPTextModel.from_pretrained( + base_model_name, + subfolder="text_encoder", + torch_dtype=torch.float32, + ) + + text_encoder_2 = CLIPTextModelWithProjection.from_pretrained( + base_model_name, + subfolder="text_encoder_2", + torch_dtype=torch.float32, + ) + + vae = AutoencoderKL.from_pretrained( + base_model_name, + subfolder="vae", + torch_dtype=torch.float32, + ) + + unet = UNet2DConditionModel.from_pretrained( + base_model_name, + subfolder="unet", + torch_dtype=torch.float32, + ) + + return StableDiffusionXLModel( + model_type=model_type, + tokenizer_1=tokenizer_1, + tokenizer_2=tokenizer_2, + noise_scheduler=noise_scheduler, + text_encoder_1=text_encoder_1, + text_encoder_2=text_encoder_2, + vae=vae, + unet=unet, + ) + except: + return None + + @staticmethod + def __load_ckpt(model_type: ModelType, base_model_name: str) -> StableDiffusionXLModel | None: + try: + yaml_name = os.path.splitext(base_model_name)[0] + '.yaml' + if not os.path.exists(yaml_name): + yaml_name = os.path.splitext(base_model_name)[0] + '.yml' + if not os.path.exists(yaml_name): + yaml_name = 
StableDiffusionXLModelLoader.__default_yaml_name(model_type) + + pipeline = download_from_original_stable_diffusion_ckpt( + checkpoint_path=base_model_name, + original_config_file=yaml_name, + load_safety_checker=False, + ) + + with open(yaml_name, "r") as f: + sd_config = yaml.safe_load(f) + + return StableDiffusionXLModel( + model_type=model_type, + tokenizer=pipeline.tokenizer, + noise_scheduler=pipeline.scheduler, + text_encoder=pipeline.text_encoder.to(dtype=torch.float32), + vae=pipeline.vae.to(dtype=torch.float32), + unet=pipeline.unet.to(dtype=torch.float32), + image_depth_processor=None, # TODO + depth_estimator=None, # TODO + sd_config=sd_config, + ) + except: + return None + + @staticmethod + def __load_safetensors(model_type: ModelType, base_model_name: str) -> StableDiffusionXLModel | None: + try: + yaml_name = os.path.splitext(base_model_name)[0] + '.yaml' + if not os.path.exists(yaml_name): + yaml_name = os.path.splitext(base_model_name)[0] + '.yml' + if not os.path.exists(yaml_name): + yaml_name = StableDiffusionModelLoader.__default_yaml_name(model_type) + + pipeline = download_from_original_stable_diffusion_ckpt( + checkpoint_path=base_model_name, + original_config_file=yaml_name, + load_safety_checker=False, + from_safetensors=True, + ) + + with open(yaml_name, "r") as f: + sd_config = yaml.safe_load(f) + + return StableDiffusionXLModel( + model_type=model_type, + tokenizer=pipeline.tokenizer, + noise_scheduler=pipeline.scheduler, + text_encoder=pipeline.text_encoder.to(dtype=torch.float32), + vae=pipeline.vae.to(dtype=torch.float32), + unet=pipeline.unet.to(dtype=torch.float32), + image_depth_processor=None, # TODO + depth_estimator=None, # TODO + sd_config=sd_config, + ) + except: + return None + + def load( + self, + model_type: ModelType, + base_model_name: str, + extra_model_name: str | None + ) -> StableDiffusionXLModel | None: + # model = self.__load_internal(model_type, base_model_name) + # if model is not None: + # return model + + model = self.__load_diffusers(model_type, base_model_name) + if model is not None: + return model + + # model = self.__load_safetensors(model_type, base_model_name) + # if model is not None: + # return model + # + # model = self.__load_ckpt(model_type, base_model_name) + # if model is not None: + # return model + + raise Exception("could not load model: " + base_model_name) diff --git a/modules/modelSampler/StableDiffusionXLSampler.py b/modules/modelSampler/StableDiffusionXLSampler.py new file mode 100644 index 00000000..f17e0cce --- /dev/null +++ b/modules/modelSampler/StableDiffusionXLSampler.py @@ -0,0 +1,72 @@ +import os +from pathlib import Path +from typing import Callable + +import torch +from PIL.Image import Image + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSampler.BaseModelSampler import BaseModelSampler +from modules.util.enum.ModelType import ModelType + + +class StableDiffusionXLSampler(BaseModelSampler): + def __init__(self, model: StableDiffusionXLModel, model_type: ModelType, train_device: torch.device): + self.model = model + self.model_type = model_type + self.train_device = train_device + self.pipeline = model.create_pipeline() + + @torch.no_grad() + def __sample_base( + self, + prompt: str, + resolution: tuple[int, int], + seed: int, + steps: int, + cfg_scale: float, + cfg_rescale: float = 0.7, + text_encoder_layer_skip: int = 0, + force_last_timestep: bool = False, + ) -> Image: + generator = torch.Generator(device=self.train_device) + generator.manual_seed(seed) + 
+ height, width = resolution + + images = self.pipeline( + generator=generator, + prompt=prompt, + height=height, + width=width, + num_inference_steps=steps, + guidance_scale=cfg_scale, + ).images + + return images[0] + + def sample( + self, + prompt: str, + resolution: tuple[int, int], + seed: int, + destination: str, + text_encoder_layer_skip: int, + force_last_timestep: bool = False, + on_sample: Callable[[Image], None] = lambda _: None, + ): + image = self.__sample_base( + prompt=prompt, + resolution=resolution, + seed=seed, + steps=20, + cfg_scale=7, + cfg_rescale=0.7 if force_last_timestep else 0.0, + text_encoder_layer_skip=text_encoder_layer_skip, + force_last_timestep=force_last_timestep + ) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + image.save(destination) + + on_sample(image) diff --git a/modules/modelSaver/StableDiffusionLoRAModelSaver.py b/modules/modelSaver/StableDiffusionLoRAModelSaver.py index 6e8ee9d6..df4c44fc 100644 --- a/modules/modelSaver/StableDiffusionLoRAModelSaver.py +++ b/modules/modelSaver/StableDiffusionLoRAModelSaver.py @@ -31,7 +31,6 @@ def __save_ckpt( destination: str, dtype: torch.dtype, ): - state_dict = StableDiffusionLoRAModelSaver.__get_state_dict(model) save_state_dict = BaseModelSaver._convert_state_dict_dtype(state_dict, dtype) @@ -44,7 +43,6 @@ def __save_safetensors( destination: str, dtype: torch.dtype, ): - state_dict = StableDiffusionLoRAModelSaver.__get_state_dict(model) save_state_dict = BaseModelSaver._convert_state_dict_dtype(state_dict, dtype) diff --git a/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py b/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py new file mode 100644 index 00000000..8cda5865 --- /dev/null +++ b/modules/modelSaver/StableDiffusionXLLoRAModelSaver.py @@ -0,0 +1,120 @@ +import json +import os.path +from pathlib import Path + +import torch +from safetensors.torch import save_file +from torch import Tensor + +from modules.model.BaseModel import BaseModel +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSaver.BaseModelSaver import BaseModelSaver +from modules.util.enum.ModelFormat import ModelFormat +from modules.util.enum.ModelType import ModelType + + +class StableDiffusionXLLoRAModelSaver(BaseModelSaver): + + @staticmethod + def __get_state_dict(model: StableDiffusionXLModel) -> dict[str, Tensor]: + state_dict = {} + if model.text_encoder_1_lora is not None: + state_dict |= model.text_encoder_1_lora.state_dict() + if model.text_encoder_2_lora is not None: + state_dict |= model.text_encoder_2_lora.state_dict() + if model.unet_lora is not None: + state_dict |= model.unet_lora.state_dict() + + return state_dict + + @staticmethod + def __save_ckpt( + model: StableDiffusionXLModel, + destination: str, + dtype: torch.dtype, + ): + state_dict = StableDiffusionXLLoRAModelSaver.__get_state_dict(model) + save_state_dict = BaseModelSaver._convert_state_dict_dtype(state_dict, dtype) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + torch.save(save_state_dict, destination) + + @staticmethod + def __save_safetensors( + model: StableDiffusionXLModel, + destination: str, + dtype: torch.dtype, + ): + state_dict = StableDiffusionXLLoRAModelSaver.__get_state_dict(model) + save_state_dict = BaseModelSaver._convert_state_dict_dtype(state_dict, dtype) + + os.makedirs(Path(destination).parent.absolute(), exist_ok=True) + save_file(save_state_dict, destination) + + @staticmethod + def __save_internal( + model: StableDiffusionXLModel, + 
destination: str, + ): + text_encoder_1_dtype = None if model.text_encoder_1_lora is None else \ + model.text_encoder_1_lora.parameters()[0].data.dtype + + text_encoder_2_dtype = None if model.text_encoder_2_lora is None else \ + model.text_encoder_2_lora.parameters()[0].data.dtype + + unet_dtype = None if model.unet_lora is None else \ + model.unet_lora.parameters()[0].data.dtype + + if text_encoder_1_dtype is not None and text_encoder_1_dtype != torch.float32 \ + or text_encoder_2_dtype is not None and text_encoder_2_dtype != torch.float32 \ + or unet_dtype is not None and unet_dtype != torch.float32: + # The internal model format requires float32 weights. + # Other formats don't have the required precision for training. + raise ValueError("Model weights need to be in float32 format. Something has gone wrong!") + + os.makedirs(destination, exist_ok=True) + + # lora + StableDiffusionXLLoRAModelSaver.__save_ckpt( + model, + os.path.join(destination, "lora", "lora.pt"), + torch.float32 + ) + + # optimizer + os.makedirs(os.path.join(destination, "optimizer"), exist_ok=True) + torch.save(model.optimizer.state_dict(), os.path.join(destination, "optimizer", "optimizer.pt")) + + # ema + if model.ema: + os.makedirs(os.path.join(destination, "ema"), exist_ok=True) + torch.save(model.ema.state_dict(), os.path.join(destination, "ema", "ema.pt")) + + # meta + with open(os.path.join(destination, "meta.json"), "w") as meta_file: + json.dump({ + 'train_progress': { + 'epoch': model.train_progress.epoch, + 'epoch_step': model.train_progress.epoch_step, + 'epoch_sample': model.train_progress.epoch_sample, + 'global_step': model.train_progress.global_step, + }, + }, meta_file) + + def save( + self, + model: BaseModel, + model_type: ModelType, + output_model_format: ModelFormat, + output_model_destination: str, + dtype: torch.dtype, + ): + match output_model_format: + case ModelFormat.DIFFUSERS: + raise NotImplementedError + case ModelFormat.CKPT: + self.__save_ckpt(model, output_model_destination, dtype) + case ModelFormat.SAFETENSORS: + self.__save_safetensors(model, output_model_destination, dtype) + case ModelFormat.INTERNAL: + self.__save_internal(model, output_model_destination) diff --git a/modules/modelSetup/BaseModelSetup.py b/modules/modelSetup/BaseModelSetup.py index 05e0a8c6..5abb2f98 100644 --- a/modules/modelSetup/BaseModelSetup.py +++ b/modules/modelSetup/BaseModelSetup.py @@ -39,6 +39,14 @@ def save_image(self, image_tensor: Tensor, directory: str, name: str, step: int) image = t(image_tensor.squeeze()) image.save(path) + def project_latent_to_image(self, latent_tensor: Tensor): + generator = torch.Generator(device=latent_tensor.device) + generator.manual_seed(42) + weight = torch.randn((3, 4, 1, 1), generator=generator, device=latent_tensor.device, dtype=latent_tensor.dtype) + + with torch.no_grad(): + return torch.nn.functional.conv2d(latent_tensor, weight) / 3.0 + @abstractmethod def create_parameters( self, diff --git a/modules/modelSetup/BaseStableDiffusionXLSetup.py b/modules/modelSetup/BaseStableDiffusionXLSetup.py new file mode 100644 index 00000000..51ee6725 --- /dev/null +++ b/modules/modelSetup/BaseStableDiffusionXLSetup.py @@ -0,0 +1,275 @@ +from abc import ABCMeta +from typing import Callable, Optional, Dict, Any + +import torch +import torch.nn.functional as F +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.attention_processor import AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 +from diffusers.utils import is_xformers_available 
+from torch import Tensor, nn + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseModelSetup import BaseModelSetup +from modules.util import loss_util +from modules.util.TrainProgress import TrainProgress +from modules.util.args.TrainArgs import TrainArgs +from modules.util.enum.AttentionMechanism import AttentionMechanism + + +class BaseStableDiffusionXLSetup(BaseModelSetup, metaclass=ABCMeta): + + def __create_basic_transformer_block_forward(self, orig_module) -> Callable: + orig_forward = orig_module.forward + + def forward( + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + return torch.utils.checkpoint.checkpoint( + orig_forward, + hidden_states, # hidden_states + attention_mask, # attention_mask + encoder_hidden_states, # encoder_hidden_states + encoder_attention_mask, # encoder_attention_mask + timestep, # timestep + cross_attention_kwargs, # cross_attention_kwargs + class_labels, # class_labels + use_reentrant=False + ) + + return forward + + def __enable_checkpointing_for_transformer_blocks(self, orig_module: nn.Module): + for name, child_module in orig_module.named_modules(): + if isinstance(child_module, BasicTransformerBlock): + child_module.forward = self.__create_basic_transformer_block_forward(child_module) + + def setup_optimizations( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ): + if args.attention_mechanism == AttentionMechanism.DEFAULT: + model.unet.set_attn_processor(AttnProcessor()) + pass + elif args.attention_mechanism == AttentionMechanism.XFORMERS and is_xformers_available(): + try: + model.unet.set_attn_processor(XFormersAttnProcessor()) + model.vae.enable_xformers_memory_efficient_attention() + except Exception as e: + print( + "Could not enable memory efficient attention. Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + elif args.attention_mechanism == AttentionMechanism.SDP: + model.unet.set_attn_processor(AttnProcessor2_0()) + + if is_xformers_available(): + try: + model.vae.enable_xformers_memory_efficient_attention() + except Exception as e: + print( + "Could not enable memory efficient attention. 
Make sure xformers is installed" + f" correctly and a GPU is available: {e}" + ) + + model.unet.enable_gradient_checkpointing() + self.__enable_checkpointing_for_transformer_blocks(model.unet) + model.text_encoder_1.gradient_checkpointing_enable() + model.text_encoder_2.gradient_checkpointing_enable() + + def predict( + self, + model: StableDiffusionXLModel, + batch: dict, + args: TrainArgs, + train_progress: TrainProgress + ) -> dict: + vae_scaling_factor = model.vae.config['scaling_factor'] + model.noise_scheduler.set_timesteps(model.noise_scheduler.config['num_train_timesteps']) + + latent_image = batch['latent_image'] + scaled_latent_image = latent_image * vae_scaling_factor + + generator = torch.Generator(device=args.train_device) + generator.manual_seed(train_progress.global_step) + + if args.offset_noise_weight > 0: + normal_noise = torch.randn( + scaled_latent_image.shape, generator=generator, device=args.train_device, + dtype=args.train_dtype.torch_dtype() + ) + offset_noise = torch.randn( + scaled_latent_image.shape[0], scaled_latent_image.shape[1], 1, 1, + generator=generator, device=args.train_device, dtype=args.train_dtype.torch_dtype() + ) + latent_noise = normal_noise + (args.offset_noise_weight * offset_noise) + else: + latent_noise = torch.randn( + scaled_latent_image.shape, generator=generator, device=args.train_device, + dtype=args.train_dtype.torch_dtype() + ) + + timestep = torch.randint( + low=0, + high=int(model.noise_scheduler.config['num_train_timesteps'] * args.max_noising_strength), + size=(scaled_latent_image.shape[0],), + device=scaled_latent_image.device, + ).long() + + scaled_noisy_latent_image = model.noise_scheduler.add_noise( + original_samples=scaled_latent_image, noise=latent_noise, timesteps=timestep + ) + + if args.train_text_encoder: + text_encoder_1_output = model.text_encoder_1( + batch['tokens_1'], output_hidden_states=True, return_dict=True + ) + text_encoder_1_output = text_encoder_1_output.hidden_states[-2] + + text_encoder_2_output = model.text_encoder_2( + batch['tokens_2'], output_hidden_states=True, return_dict=True + ) + pooled_text_encoder_2_output = text_encoder_2_output.text_embeds + text_encoder_2_output = text_encoder_2_output.hidden_states[-2] + else: + text_encoder_1_output = batch['text_encoder_1_hidden_state'][-2] + text_encoder_2_output = batch['text_encoder_2_hidden_state'][-2] + pooled_text_encoder_2_output = batch['text_encoder_2_pooled_state'] + + text_encoder_output = torch.concat([text_encoder_1_output, text_encoder_2_output], dim=-1) + + original_height = 1024 + original_width = 1024 + crops_coords_top = 0 + crops_coords_left = 0 + target_height = 1024 + target_width = 1024 + add_time_ids = list( + (original_height, original_width) + + (crops_coords_top, crops_coords_left) + + (target_height, target_width) + ) + add_time_ids = torch.tensor( + [add_time_ids], + dtype=scaled_noisy_latent_image.dtype, + device=scaled_noisy_latent_image.device, + ) + add_time_ids = torch.concat([add_time_ids] * pooled_text_encoder_2_output.shape[0]) + + latent_input = scaled_noisy_latent_image + added_cond_kwargs = {"text_embeds": pooled_text_encoder_2_output, "time_ids": add_time_ids} + predicted_latent_noise = model.unet( + sample=latent_input, + timestep=timestep, + encoder_hidden_states=text_encoder_output, + added_cond_kwargs=added_cond_kwargs, + ).sample + + model_output_data = {} + + if model.noise_scheduler.config.prediction_type == 'epsilon': + model_output_data = { + 'predicted': predicted_latent_noise, + 'target': latent_noise, + } + 
elif model.noise_scheduler.config.prediction_type == 'v_prediction': + target_velocity = model.noise_scheduler.get_velocity(scaled_latent_image, latent_noise, timestep) + model_output_data = { + 'predicted': predicted_latent_noise, + 'target': target_velocity, + } + + if args.debug_mode: + with torch.autocast(self.train_device.type, enabled=False): + with torch.no_grad(): + # noise + self.save_image( + self.project_latent_to_image(latent_noise).clamp(-1, 1), + args.debug_dir + "/training_batches", + "1-noise", + train_progress.global_step + ) + + # predicted noise + self.save_image( + self.project_latent_to_image(predicted_latent_noise).clamp(-1, 1), + args.debug_dir + "/training_batches", + "2-predicted_noise", + train_progress.global_step + ) + + # noisy image + self.save_image( + self.project_latent_to_image(scaled_noisy_latent_image).clamp(-1, 1), + args.debug_dir + "/training_batches", + "3-noisy_image", + train_progress.global_step + ) + + # predicted image + alphas_cumprod = model.noise_scheduler.alphas_cumprod.to(args.train_device) + sqrt_alpha_prod = alphas_cumprod[timestep] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten().reshape(-1, 1, 1, 1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timestep]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten().reshape(-1, 1, 1, 1) + + scaled_predicted_latent_image = \ + (scaled_noisy_latent_image - predicted_latent_noise * sqrt_one_minus_alpha_prod) \ + / sqrt_alpha_prod + self.save_image( + self.project_latent_to_image(scaled_predicted_latent_image).clamp(-1, 1), + args.debug_dir + "/training_batches", + "4-predicted_image", + model.train_progress.global_step + ) + + # image + self.save_image( + self.project_latent_to_image(scaled_latent_image).clamp(-1, 1), + args.debug_dir + "/training_batches", + "5-image", + model.train_progress.global_step + ) + + return model_output_data + + def calculate_loss( + self, + model: StableDiffusionXLModel, + batch: dict, + data: dict, + args: TrainArgs, + ) -> Tensor: + predicted = data['predicted'] + target = data['target'] + + # TODO: don't disable masked loss functions when has_conditioning_image_input is true. 
+ # This breaks if only the VAE is trained, but was loaded from an inpainting checkpoint + if args.masked_training and not args.model_type.has_conditioning_image_input(): + losses = loss_util.masked_loss( + F.mse_loss, + predicted, + target, + batch['latent_mask'], + args.unmasked_weight, + args.normalize_masked_area_loss + ).mean([1, 2, 3]) + else: + losses = F.mse_loss( + predicted, + target, + reduction='none' + ).mean([1, 2, 3]) + + if args.normalize_masked_area_loss: + clamped_mask = torch.clamp(batch['latent_mask'], args.unmasked_weight, 1) + losses = losses / clamped_mask.mean(dim=(1, 2, 3)) + + return losses.mean() diff --git a/modules/modelSetup/StableDiffusionLoRASetup.py b/modules/modelSetup/StableDiffusionLoRASetup.py index 790562a1..6cefc379 100644 --- a/modules/modelSetup/StableDiffusionLoRASetup.py +++ b/modules/modelSetup/StableDiffusionLoRASetup.py @@ -134,7 +134,7 @@ def setup_train_device( args: TrainArgs, ): model.text_encoder.to(self.train_device) - model.vae.to(self.temp_device) + model.vae.to(self.train_device if self.debug_mode else self.temp_device) model.unet.to(self.train_device) if model.depth_estimator is not None: model.depth_estimator.to(self.temp_device) diff --git a/modules/modelSetup/StableDiffusionXLFineTuneSetup.py b/modules/modelSetup/StableDiffusionXLFineTuneSetup.py new file mode 100644 index 00000000..c71b454e --- /dev/null +++ b/modules/modelSetup/StableDiffusionXLFineTuneSetup.py @@ -0,0 +1,141 @@ +from typing import Iterable + +import torch +from torch.nn import Parameter + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup +from modules.util import create +from modules.util.TrainProgress import TrainProgress +from modules.util.args.TrainArgs import TrainArgs + + +class StableDiffusionXLFineTuneSetup(BaseStableDiffusionXLSetup): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super(StableDiffusionXLFineTuneSetup, self).__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ) -> Iterable[Parameter]: + params = list() + + if args.train_text_encoder: + params += list(model.text_encoder_1.parameters()) + + if args.train_unet: + params += list(model.unet.parameters()) + + return params + + def create_parameters_for_optimizer( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ) -> Iterable[Parameter] | list[dict]: + param_groups = list() + + if args.train_text_encoder: + lr = args.text_encoder_learning_rate if args.text_encoder_learning_rate is not None else args.learning_rate + param_groups.append({ + 'params': model.text_encoder_1.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + param_groups.append({ + 'params': model.text_encoder_2.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + + if args.train_unet: + lr = args.unet_learning_rate if args.unet_learning_rate is not None else args.learning_rate + param_groups.append({ + 'params': model.unet.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + + return param_groups + + def setup_model( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ): + train_text_encoder = args.train_text_encoder and (model.train_progress.epoch < args.train_text_encoder_epochs) + model.text_encoder_1.requires_grad_(train_text_encoder) + + train_unet = args.train_unet and (model.train_progress.epoch < 
args.train_unet_epochs) + model.unet.requires_grad_(train_unet) + + model.vae.requires_grad_(False) + + model.optimizer = create.create_optimizer( + self.create_parameters_for_optimizer(model, args), model.optimizer_state_dict, args + ) + del model.optimizer_state_dict + + model.ema = create.create_ema( + self.create_parameters(model, args), model.ema_state_dict, args + ) + del model.ema_state_dict + + self.setup_optimizations(model, args) + + def setup_eval_device( + self, + model: StableDiffusionXLModel + ): + model.text_encoder_1.to(self.train_device) + model.text_encoder_2.to(self.train_device) + model.vae.to(self.train_device) + model.unet.to(self.train_device) + + model.text_encoder_1.eval() + model.text_encoder_2.eval() + model.vae.eval() + model.unet.eval() + + def setup_train_device( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ): + model.text_encoder_1.to(self.train_device if args.train_text_encoder else self.temp_device) + model.text_encoder_2.to(self.train_device if args.train_text_encoder else self.temp_device) + model.vae.to(self.temp_device) + model.unet.to(self.train_device) + + if args.train_text_encoder: + model.text_encoder_1.train() + model.text_encoder_2.train() + else: + model.text_encoder_1.eval() + model.text_encoder_2.eval() + + model.vae.train() + model.unet.train() + + def after_optimizer_step( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + train_progress: TrainProgress + ): + train_text_encoder = args.train_text_encoder and (model.train_progress.epoch < args.train_text_encoder_epochs) + model.text_encoder_1.requires_grad_(train_text_encoder) + model.text_encoder_2.requires_grad_(train_text_encoder) + + train_unet = args.train_unet and (model.train_progress.epoch < args.train_unet_epochs) + model.unet.requires_grad_(train_unet) diff --git a/modules/modelSetup/StableDiffusionXLLoRASetup.py b/modules/modelSetup/StableDiffusionXLLoRASetup.py new file mode 100644 index 00000000..dfd7808e --- /dev/null +++ b/modules/modelSetup/StableDiffusionXLLoRASetup.py @@ -0,0 +1,183 @@ +from typing import Iterable + +import torch +from torch.nn import Parameter + +from modules.model.StableDiffusionXLModel import StableDiffusionXLModel +from modules.modelSetup.BaseStableDiffusionXLSetup import BaseStableDiffusionXLSetup +from modules.module.LoRAModule import LoRAModuleWrapper +from modules.util import create +from modules.util.TrainProgress import TrainProgress +from modules.util.args.TrainArgs import TrainArgs + + +class StableDiffusionXLLoRASetup(BaseStableDiffusionXLSetup): + def __init__( + self, + train_device: torch.device, + temp_device: torch.device, + debug_mode: bool, + ): + super(StableDiffusionXLLoRASetup, self).__init__( + train_device=train_device, + temp_device=temp_device, + debug_mode=debug_mode, + ) + + def create_parameters( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ) -> Iterable[Parameter]: + params = list() + + if args.train_text_encoder: + params += list(model.text_encoder_1_lora.parameters()) + params += list(model.text_encoder_2_lora.parameters()) + + if args.train_unet: + params += list(model.unet_lora.parameters()) + + return params + + def create_parameters_for_optimizer( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ) -> Iterable[Parameter] | list[dict]: + param_groups = list() + + if args.train_text_encoder: + lr = args.text_encoder_learning_rate if args.text_encoder_learning_rate is not None else args.learning_rate + param_groups.append({ + 'params': 
model.text_encoder_1_lora.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + param_groups.append({ + 'params': model.text_encoder_2_lora.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + + if args.train_unet: + lr = args.unet_learning_rate if args.unet_learning_rate is not None else args.learning_rate + param_groups.append({ + 'params': model.unet_lora.parameters(), + 'lr': lr, + 'initial_lr': lr, + }) + + return param_groups + + def setup_model( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ): + if model.text_encoder_1_lora is None: + model.text_encoder_1_lora = LoRAModuleWrapper( + model.text_encoder_1, args.lora_rank, "lora_te", args.lora_alpha + ) + + if model.text_encoder_2_lora is None: + model.text_encoder_2_lora = LoRAModuleWrapper( + model.text_encoder_2, args.lora_rank, "lora_te_2", args.lora_alpha + ) + + if model.unet_lora is None: + model.unet_lora = LoRAModuleWrapper( + model.unet, args.lora_rank, "lora_unet", args.lora_alpha, ["attentions"] + ) + + model.text_encoder_1.requires_grad_(False) + model.text_encoder_2.requires_grad_(False) + model.unet.requires_grad_(False) + model.vae.requires_grad_(False) + + train_text_encoder = args.train_text_encoder and (model.train_progress.epoch < args.train_text_encoder_epochs) + model.text_encoder_1_lora.requires_grad_(train_text_encoder) + model.text_encoder_2_lora.requires_grad_(train_text_encoder) + + train_unet = args.train_unet and (model.train_progress.epoch < args.train_unet_epochs) + model.unet_lora.requires_grad_(train_unet) + + model.text_encoder_1_lora.hook_to_module() + model.text_encoder_2_lora.hook_to_module() + model.unet_lora.hook_to_module() + + model.optimizer = create.create_optimizer( + self.create_parameters_for_optimizer(model, args), model.optimizer_state_dict, args + ) + del model.optimizer_state_dict + + model.ema = create.create_ema( + self.create_parameters(model, args), model.ema_state_dict, args + ) + del model.ema_state_dict + + self.setup_optimizations(model, args) + + def setup_eval_device( + self, + model: StableDiffusionXLModel + ): + model.text_encoder_1.to(self.train_device) + model.text_encoder_2.to(self.train_device) + model.vae.to(self.train_device) + model.unet.to(self.train_device) + + if model.text_encoder_1_lora is not None: + model.text_encoder_1_lora.to(self.train_device) + + if model.text_encoder_2_lora is not None: + model.text_encoder_2_lora.to(self.train_device) + + if model.unet_lora is not None: + model.unet_lora.to(self.train_device) + + model.text_encoder_1.eval() + model.text_encoder_2.eval() + model.vae.eval() + model.unet.eval() + + def setup_train_device( + self, + model: StableDiffusionXLModel, + args: TrainArgs, + ): + model.text_encoder_1.to(self.train_device if args.train_text_encoder else self.temp_device) + model.text_encoder_2.to(self.train_device if args.train_text_encoder else self.temp_device) + model.vae.to(self.temp_device) + model.unet.to(self.train_device) + + if model.text_encoder_1_lora is not None and args.train_text_encoder: + model.text_encoder_1_lora.to(self.train_device) + + if model.text_encoder_2_lora is not None and args.train_text_encoder: + model.text_encoder_2_lora.to(self.train_device) + + if model.unet_lora is not None: + model.unet_lora.to(self.train_device) + + if args.train_text_encoder: + model.text_encoder_1.train() + model.text_encoder_2.train() + else: + model.text_encoder_1.eval() + model.text_encoder_2.eval() + model.vae.eval() + model.unet.train() + + def after_optimizer_step( + self, + model: StableDiffusionXLModel, + 
args: TrainArgs, + train_progress: TrainProgress + ): + train_text_encoder = args.train_text_encoder and (model.train_progress.epoch < args.train_text_encoder_epochs) + model.text_encoder_1_lora.requires_grad_(train_text_encoder) + model.text_encoder_2_lora.requires_grad_(train_text_encoder) + + train_unet = args.train_unet and (model.train_progress.epoch < args.train_unet_epochs) + model.unet_lora.requires_grad_(train_unet) diff --git a/modules/trainer/GenericTrainer.py b/modules/trainer/GenericTrainer.py index 16a76d88..e3e7e479 100644 --- a/modules/trainer/GenericTrainer.py +++ b/modules/trainer/GenericTrainer.py @@ -1,3 +1,4 @@ +import gc import json import os import subprocess @@ -96,6 +97,11 @@ def start(self): self.parameters = list(self.model_setup.create_parameters(self.model, self.args)) + def __gc(self): + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + def __enqueue_sample_during_training(self, fun: Callable): self.sample_queue.append(fun) @@ -134,10 +140,10 @@ def on_sample(image: Image): on_sample=on_sample, ) - torch.cuda.empty_cache() + self.__gc() def __sample_during_training(self, train_progress: TrainProgress, sample_definitions: list[dict] = None): - torch.cuda.empty_cache() + self.__gc() self.callbacks.on_update_status("sampling") @@ -161,10 +167,10 @@ def __sample_during_training(self, train_progress: TrainProgress, sample_definit self.model_setup.setup_train_device(self.model, self.args) - torch.cuda.empty_cache() + self.__gc() def backup(self): - torch.cuda.empty_cache() + self.__gc() self.callbacks.on_update_status("creating backup") @@ -179,7 +185,7 @@ def backup(self): ) self.model_setup.setup_train_device(self.model, self.args) - torch.cuda.empty_cache() + self.__gc() def __needs_sample(self, train_progress: TrainProgress): return self.action_needed("sample", self.args.sample_after, self.args.sample_after_unit, train_progress) @@ -233,7 +239,7 @@ def train(self): self.model_setup.setup_eval_device(self.model) self.data_loader.ds.start_next_epoch() self.model_setup.setup_train_device(self.model, self.args) - torch.cuda.empty_cache() + self.__gc() current_epoch_length = len(self.data_loader.dl) + train_progress.epoch_step for epoch_step, batch in enumerate(tqdm(self.data_loader.dl, desc="step")): @@ -245,7 +251,7 @@ def train(self): sample_command = self.commands.get_and_reset_sample_command() if sample_command: self.__enqueue_sample_during_training( - lambda: self.__sample_during_training(train_progress, sample_command) + lambda: self.__sample_during_training(train_progress, [sample_command]) ) if not has_gradient: diff --git a/modules/ui/TrainUI.py b/modules/ui/TrainUI.py index 27ec79bd..fc3ebedc 100644 --- a/modules/ui/TrainUI.py +++ b/modules/ui/TrainUI.py @@ -178,6 +178,7 @@ def model_tab(self, master): ("Stable Diffusion 2.0", ModelType.STABLE_DIFFUSION_20), ("Stable Diffusion 2.0 Inpainting", ModelType.STABLE_DIFFUSION_20_INPAINTING), ("Stable Diffusion 2.1", ModelType.STABLE_DIFFUSION_21), + ("Stable Diffusion XL 0.9 Base", ModelType.STABLE_DIFFUSION_XL_10_BASE), ], self.ui_state, "model_type") # output model destination diff --git a/modules/util/args/TrainArgs.py b/modules/util/args/TrainArgs.py index 36428cf0..2416189a 100644 --- a/modules/util/args/TrainArgs.py +++ b/modules/util/args/TrainArgs.py @@ -24,6 +24,7 @@ class TrainArgs: # model settings base_model_name: str extra_model_name: str + weight_dtype: DataType output_dtype: DataType output_model_format: ModelFormat output_model_destination: str @@ -174,6 +175,7 @@ def 
parse_args() -> 'TrainArgs': # model settings parser.add_argument("--base-model-name", type=str, required=True, dest="base_model_name", help="The base model to start training from") parser.add_argument("--extra-model-name", type=str, required=False, default=None, dest="extra_model_name", help="The extra model to start training from") + parser.add_argument("--weight-dtype", type=DataType, required=False, default=DataType.FLOAT_32, dest="weight_dtype", help="The data type to use for weights during training", choices=list(DataType)) parser.add_argument("--output-dtype", type=DataType, required=True, dest="output_dtype", help="The data type to use for saving weights", choices=list(DataType)) parser.add_argument("--output-model-format", type=ModelFormat, required=False, default=ModelFormat.CKPT, dest="output_model_format", help="The format to save the final output model", choices=list(ModelFormat)) parser.add_argument("--output-model-destination", type=str, required=True, dest="output_model_destination", help="The destination to save the final output model") @@ -255,6 +257,7 @@ def default_values(): # model settings args["base_model_name"] = "runwayml/stable-diffusion-v1-5" args["extra_model_name"] = "" + args["weight_dtype"] = DataType.FLOAT_32 args["output_dtype"] = DataType.FLOAT_32 args["output_model_format"] = ModelFormat.CKPT args["output_model_destination"] = "models/model.ckpt" diff --git a/modules/util/create.py b/modules/util/create.py index 6ca2695e..9d3f3bfb 100644 --- a/modules/util/create.py +++ b/modules/util/create.py @@ -8,6 +8,7 @@ from modules.dataLoader.MgdsStableDiffusionEmbeddingDataLoader import MgdsStableDiffusionEmbeddingDataLoader from modules.dataLoader.MgdsStableDiffusionFineTuneDataLoader import MgdsStableDiffusionFineTuneDataLoader from modules.dataLoader.MgdsStableDiffusionFineTuneVaeDataLoader import MgdsStableDiffusionFineTuneVaeDataLoader +from modules.dataLoader.MgdsStableDiffusionFineXLTuneDataLoader import MgdsStableDiffusionXLFineTuneDataLoader from modules.model.BaseModel import BaseModel from modules.modelLoader.BaseModelLoader import BaseModelLoader from modules.modelLoader.KandinskyLoRAModelLoader import KandinskyLoRAModelLoader @@ -15,16 +16,20 @@ from modules.modelLoader.StableDiffusionEmbeddingModelLoader import StableDiffusionEmbeddingModelLoader from modules.modelLoader.StableDiffusionLoRAModelLoader import StableDiffusionLoRAModelLoader from modules.modelLoader.StableDiffusionModelLoader import StableDiffusionModelLoader +from modules.modelLoader.StableDiffusionXLLoRAModelLoader import StableDiffusionXLLoRAModelLoader +from modules.modelLoader.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader from modules.modelSampler import BaseModelSampler from modules.modelSampler.KandinskySampler import KandinskySampler from modules.modelSampler.StableDiffusionSampler import StableDiffusionSampler from modules.modelSampler.StableDiffusionVaeSampler import StableDiffusionVaeSampler +from modules.modelSampler.StableDiffusionXLSampler import StableDiffusionXLSampler from modules.modelSaver.BaseModelSaver import BaseModelSaver from modules.modelSaver.KandinskyDiffusionModelSaver import KandinskyModelSaver from modules.modelSaver.KandinskyLoRAModelSaver import KandinskyLoRAModelSaver from modules.modelSaver.StableDiffusionEmbeddingModelSaver import StableDiffusionEmbeddingModelSaver from modules.modelSaver.StableDiffusionLoRAModelSaver import StableDiffusionLoRAModelSaver from modules.modelSaver.StableDiffusionModelSaver import
StableDiffusionModelSaver +from modules.modelSaver.StableDiffusionXLLoRAModelSaver import StableDiffusionXLLoRAModelSaver from modules.modelSetup.BaseModelSetup import BaseModelSetup from modules.modelSetup.KandinskyFineTuneSetup import KandinskyFineTuneSetup from modules.modelSetup.KandinskyLoRASetup import KandinskyLoRASetup @@ -32,6 +37,8 @@ from modules.modelSetup.StableDiffusionFineTuneSetup import StableDiffusionFineTuneSetup from modules.modelSetup.StableDiffusionFineTuneVaeSetup import StableDiffusionFineTuneVaeSetup from modules.modelSetup.StableDiffusionLoRASetup import StableDiffusionLoRASetup +from modules.modelSetup.StableDiffusionXLFineTuneSetup import StableDiffusionXLFineTuneSetup +from modules.modelSetup.StableDiffusionXLLoRASetup import StableDiffusionXLLoRASetup from modules.module.EMAModule import EMAModuleWrapper from modules.util.TrainProgress import TrainProgress from modules.util.args.TrainArgs import TrainArgs @@ -51,6 +58,8 @@ def create_model_loader( case TrainingMethod.FINE_TUNE: if model_type.is_stable_diffusion(): return StableDiffusionModelLoader() + elif model_type.is_stable_diffusion_xl(): + return StableDiffusionXLModelLoader() elif model_type.is_kandinsky(): return KandinskyModelLoader() case TrainingMethod.FINE_TUNE_VAE: @@ -59,7 +68,9 @@ def create_model_loader( case TrainingMethod.LORA: if model_type.is_stable_diffusion(): return StableDiffusionLoRAModelLoader() - if model_type.is_kandinsky(): + elif model_type.is_stable_diffusion_xl(): + return StableDiffusionXLLoRAModelLoader() + elif model_type.is_kandinsky(): return KandinskyLoRAModelLoader() case TrainingMethod.EMBEDDING: if model_type.is_stable_diffusion(): @@ -82,6 +93,8 @@ def create_model_saver( case TrainingMethod.LORA: if model_type.is_stable_diffusion(): return StableDiffusionLoRAModelSaver() + if model_type.is_stable_diffusion_xl(): + return StableDiffusionXLLoRAModelSaver() if model_type.is_kandinsky(): return KandinskyLoRAModelSaver() case TrainingMethod.EMBEDDING: @@ -100,6 +113,8 @@ def create_model_setup( case TrainingMethod.FINE_TUNE: if model_type.is_stable_diffusion(): return StableDiffusionFineTuneSetup(train_device, temp_device, debug_mode) + if model_type.is_stable_diffusion_xl(): + return StableDiffusionXLFineTuneSetup(train_device, temp_device, debug_mode) elif model_type.is_kandinsky(): return KandinskyFineTuneSetup(train_device, temp_device, debug_mode) case TrainingMethod.FINE_TUNE_VAE: @@ -108,6 +123,8 @@ def create_model_setup( case TrainingMethod.LORA: if model_type.is_stable_diffusion(): return StableDiffusionLoRASetup(train_device, temp_device, debug_mode) + if model_type.is_stable_diffusion_xl(): + return StableDiffusionXLLoRASetup(train_device, temp_device, debug_mode) if model_type.is_kandinsky(): return KandinskyLoRASetup(train_device, temp_device, debug_mode) case TrainingMethod.EMBEDDING: @@ -125,6 +142,8 @@ def create_model_sampler( case TrainingMethod.FINE_TUNE: if model_type.is_stable_diffusion(): return StableDiffusionSampler(model, model_type, train_device) + if model_type.is_stable_diffusion_xl(): + return StableDiffusionXLSampler(model, model_type, train_device) if model_type.is_kandinsky(): return KandinskySampler(model, model_type, train_device) case TrainingMethod.FINE_TUNE_VAE: @@ -133,6 +152,8 @@ def create_model_sampler( case TrainingMethod.LORA: if model_type.is_stable_diffusion(): return StableDiffusionSampler(model, model_type, train_device) + if model_type.is_stable_diffusion_xl(): + return StableDiffusionXLSampler(model, model_type, 
train_device) if model_type.is_kandinsky(): return KandinskySampler(model, model_type, train_device) case TrainingMethod.EMBEDDING: @@ -151,6 +172,8 @@ def create_data_loader( case TrainingMethod.FINE_TUNE: if model_type.is_stable_diffusion(): return MgdsStableDiffusionFineTuneDataLoader(args, model, train_progress) + if model_type.is_stable_diffusion_xl(): + return MgdsStableDiffusionXLFineTuneDataLoader(args, model, train_progress) elif model_type.is_kandinsky(): return MgdsKandinskyFineTuneDataLoader(args, model, train_progress) case TrainingMethod.FINE_TUNE_VAE: @@ -159,6 +182,8 @@ def create_data_loader( case TrainingMethod.LORA: if model_type.is_stable_diffusion(): return MgdsStableDiffusionFineTuneDataLoader(args, model, train_progress) + if model_type.is_stable_diffusion_xl(): + return MgdsStableDiffusionXLFineTuneDataLoader(args, model, train_progress) if model_type.is_kandinsky(): return MgdsKandinskyFineTuneDataLoader(args, model, train_progress) case TrainingMethod.EMBEDDING: diff --git a/modules/util/enum/ModelType.py b/modules/util/enum/ModelType.py index f6416599..afd934e1 100644 --- a/modules/util/enum/ModelType.py +++ b/modules/util/enum/ModelType.py @@ -11,6 +11,8 @@ class ModelType(Enum): STABLE_DIFFUSION_21 = 'STABLE_DIFFUSION_21' STABLE_DIFFUSION_21_BASE = 'STABLE_DIFFUSION_21_BASE' + STABLE_DIFFUSION_XL_10_BASE = 'STABLE_DIFFUSION_XL_10_BASE' + KANDINSKY_21 = 'KANDINSKY_21' def __str__(self): @@ -26,6 +28,10 @@ def is_stable_diffusion(self): or self == ModelType.STABLE_DIFFUSION_21 \ or self == ModelType.STABLE_DIFFUSION_21_BASE + def is_stable_diffusion_xl(self): + return self == ModelType.STABLE_DIFFUSION_XL_10_BASE + + def is_kandinsky(self): return self == ModelType.KANDINSKY_21 diff --git a/requirements.txt b/requirements.txt index 71713c35..4217af4a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,12 +15,13 @@ tensorboard==2.13.0 pytorch-lightning==2.0.3 # stable diffusion -git+https://github.com/huggingface/diffusers.git@f96b760#egg=diffusers +git+https://github.com/huggingface/diffusers.git@78922ed#egg=diffusers git+https://github.com/huggingface/transformers.git@656e869#egg=transformers -omegaconf # needed to load ckpt files +omegaconf==2.3.0 # needed to load stable diffusion from single ckpt files +invisible-watermark==0.2.0 # needed for the SDXL pipeline # data loader -git+https://github.com/Nerogar/mgds.git@e7d758c#egg=mgds +git+https://github.com/Nerogar/mgds.git@fe5adfb#egg=mgds # xformers xformers==0.0.20 diff --git a/training_presets/#sdxl 0.9 LoRA.json b/training_presets/#sdxl 0.9 LoRA.json new file mode 100644 index 00000000..6e809d88 --- /dev/null +++ b/training_presets/#sdxl 0.9 LoRA.json @@ -0,0 +1,64 @@ +{ + "training_method": "LORA", + "debug_mode": false, + "debug_dir": "debug", + "workspace_dir": "workspace/run", + "cache_dir": "workspace-cache/run", + "tensorboard": true, + "model_type": "STABLE_DIFFUSION_XL_10_BASE", + "base_model_name": "stabilityai/stable-diffusion-xl-base-0.9", + "extra_model_name": "", + "output_dtype": "FLOAT_32", + "output_model_format": "CKPT", + "output_model_destination": "models/lora.ckpt", + "concept_file_name": "training_concepts/concepts.json", + "circular_mask_generation": false, + "random_rotate_and_crop": false, + "aspect_ratio_bucketing": true, + "latent_caching": true, + "latent_caching_epochs": 1, + "optimizer": "ADAMW", + "learning_rate_scheduler": "CONSTANT", + "learning_rate": 0.0003, + "learning_rate_warmup_steps": 200, + "learning_rate_cycles": 1, + "weight_decay": 0.01, + "epochs": 100, + 
"batch_size": 4, + "gradient_accumulation_steps": 1, + "ema": "OFF", + "ema_decay": 0.999, + "ema_update_step_interval": 5, + "train_text_encoder": false, + "train_text_encoder_epochs": 30, + "text_encoder_learning_rate": 0.0003, + "text_encoder_layer_skip": 0, + "train_unet": true, + "train_unet_epochs": 100000, + "unet_learning_rate": 0.0003, + "offset_noise_weight": 0.05, + "rescale_noise_scheduler_to_zero_terminal_snr": false, + "force_v_prediction": false, + "force_epsilon_prediction": false, + "train_device": "cuda", + "temp_device": "cpu", + "train_dtype": "FLOAT_16", + "only_cache": false, + "resolution": 1024, + "masked_training": false, + "unmasked_probability": 0.1, + "unmasked_weight": 0.1, + "normalize_masked_area_loss": false, + "max_noising_strength": 1.0, + "token_count": 1, + "initial_embedding_text": "*", + "lora_rank": 16, + "lora_alpha": 1.0, + "attention_mechanism": "XFORMERS", + "sample_definition_file_name": "training_samples/samples.json", + "sample_after": 10, + "sample_after_unit": "MINUTE", + "backup_after": 30, + "backup_after_unit": "MINUTE", + "backup_before_save": true +} \ No newline at end of file diff --git a/training_presets/#sdxl 0.9.json b/training_presets/#sdxl 0.9.json new file mode 100644 index 00000000..90117073 --- /dev/null +++ b/training_presets/#sdxl 0.9.json @@ -0,0 +1,64 @@ +{ + "training_method": "FINE_TUNE", + "debug_mode": false, + "debug_dir": "debug", + "workspace_dir": "workspace/run", + "cache_dir": "workspace-cache/run", + "tensorboard": true, + "model_type": "STABLE_DIFFUSION_XL_10_BASE", + "base_model_name": "stabilityai/stable-diffusion-xl-base-0.9", + "extra_model_name": "", + "output_dtype": "FLOAT_32", + "output_model_format": "CKPT", + "output_model_destination": "models/model.ckpt", + "concept_file_name": "training_concepts/concepts.json", + "circular_mask_generation": false, + "random_rotate_and_crop": false, + "aspect_ratio_bucketing": true, + "latent_caching": true, + "latent_caching_epochs": 1, + "optimizer": "ADAMW", + "learning_rate_scheduler": "CONSTANT", + "learning_rate": 3e-06, + "learning_rate_warmup_steps": 200, + "learning_rate_cycles": 1, + "weight_decay": 0.01, + "epochs": 100, + "batch_size": 4, + "gradient_accumulation_steps": 1, + "ema": "OFF", + "ema_decay": 0.999, + "ema_update_step_interval": 5, + "train_text_encoder": false, + "train_text_encoder_epochs": 30, + "text_encoder_learning_rate": 3e-06, + "text_encoder_layer_skip": 0, + "train_unet": true, + "train_unet_epochs": 100000, + "unet_learning_rate": 3e-06, + "offset_noise_weight": 0.05, + "rescale_noise_scheduler_to_zero_terminal_snr": false, + "force_v_prediction": false, + "force_epsilon_prediction": false, + "train_device": "cuda", + "temp_device": "cpu", + "train_dtype": "FLOAT_16", + "only_cache": false, + "resolution": 1024, + "masked_training": false, + "unmasked_probability": 0.1, + "unmasked_weight": 0.1, + "normalize_masked_area_loss": false, + "max_noising_strength": 1.0, + "token_count": 1, + "initial_embedding_text": "*", + "lora_rank": 16, + "lora_alpha": 1.0, + "attention_mechanism": "XFORMERS", + "sample_definition_file_name": "training_samples/samples.json", + "sample_after": 10, + "sample_after_unit": "MINUTE", + "backup_after": 30, + "backup_after_unit": "MINUTE", + "backup_before_save": true +} \ No newline at end of file