Skip to content

Commit

Permalink
work in progress on SDXL LoRA training
Browse files Browse the repository at this point in the history
  • Loading branch information
Nerogar committed Jul 16, 2023
1 parent 10261e6 commit 59a325c
Show file tree
Hide file tree
Showing 24 changed files with 1,757 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/venv*
/debug*
/workspace*
/models*
/training_concepts
/training_samples
debug.py
Expand Down
2 changes: 1 addition & 1 deletion modules/dataLoader/MgdsKandinskyBaseDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _load_input_modules(self, args: TrainArgs, model: KandinskyModel) -> list:
select_prompt_input = SelectInput(setting_name='concept.prompt_source', out_name='prompts', setting_to_in_name_map={
'sample': 'sample_prompts',
'concept': 'concept_prompts',
'filename': 'filename_prompts',
'filename': 'filename_prompt',
}, default_in_name='sample_prompts')
select_random_text = SelectRandomText(texts_in_name='prompts', text_out_name='prompt')

Expand Down
2 changes: 0 additions & 2 deletions modules/dataLoader/MgdsStableDiffusionFineTuneDataLoader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from mgds.GenericDataLoaderModules import *

from modules.dataLoader.MgdsStableDiffusionBaseDataLoader import MgdsStablDiffusionBaseDataLoader
from modules.model.StableDiffusionModel import StableDiffusionModel
from modules.util.TrainProgress import TrainProgress
Expand Down
14 changes: 14 additions & 0 deletions modules/dataLoader/MgdsStableDiffusionFineXLTuneDataLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from modules.dataLoader.MgdsStableDiffusionXLBaseDataLoader import MgdsStablDiffusionXLBaseDataLoader
from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.util.TrainProgress import TrainProgress
from modules.util.args.TrainArgs import TrainArgs


class MgdsStableDiffusionXLFineTuneDataLoader(MgdsStablDiffusionXLBaseDataLoader):
    """Data loader for full fine-tuning of a Stable Diffusion XL model.

    Thin subclass that delegates all pipeline construction to
    MgdsStablDiffusionXLBaseDataLoader; it exists so fine-tune training can
    select a loader by class, mirroring the non-XL fine-tune loader.
    """

    def __init__(
            self,
            args: TrainArgs,
            model: StableDiffusionXLModel,
            train_progress: TrainProgress,
    ):
        # No fine-tune-specific setup; the base loader handles everything.
        super().__init__(args, model, train_progress)
376 changes: 376 additions & 0 deletions modules/dataLoader/MgdsStableDiffusionXLBaseDataLoader.py

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions modules/model/StableDiffusionXLModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, EulerDiscreteScheduler, \
StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTokenizer

from modules.model.BaseModel import BaseModel
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util.TrainProgress import TrainProgress
from modules.util.enum.ModelType import ModelType


class StableDiffusionXLModel(BaseModel):
    """Container for the components of a Stable Diffusion XL model.

    Bundles the two tokenizers/text encoders used by SDXL together with the
    VAE, UNet and noise scheduler, plus optional LoRA wrappers and training
    state carried by the BaseModel superclass.
    """

    # base model data
    model_type: ModelType
    tokenizer_1: CLIPTokenizer
    tokenizer_2: CLIPTokenizer
    noise_scheduler: EulerDiscreteScheduler
    text_encoder_1: CLIPTextModel
    text_encoder_2: CLIPTextModel
    vae: AutoencoderKL
    unet: UNet2DConditionModel

    # persistent training data (None when no LoRA is attached)
    text_encoder_1_lora: LoRAModuleWrapper | None
    text_encoder_2_lora: LoRAModuleWrapper | None
    unet_lora: LoRAModuleWrapper | None

    def __init__(
            self,
            model_type: ModelType,
            tokenizer_1: CLIPTokenizer,
            tokenizer_2: CLIPTokenizer,
            noise_scheduler: EulerDiscreteScheduler,
            text_encoder_1: CLIPTextModel,
            text_encoder_2: CLIPTextModel,
            vae: AutoencoderKL,
            unet: UNet2DConditionModel,
            optimizer_state_dict: dict | None = None,
            ema_state_dict: dict | None = None,
            # fix: was annotated `TrainProgress` despite the None default;
            # align with the other optional parameters' `| None` style
            train_progress: TrainProgress | None = None,
            text_encoder_1_lora: LoRAModuleWrapper | None = None,
            text_encoder_2_lora: LoRAModuleWrapper | None = None,
            unet_lora: LoRAModuleWrapper | None = None,
    ):
        super().__init__(model_type, optimizer_state_dict, ema_state_dict, train_progress)

        self.tokenizer_1 = tokenizer_1
        self.tokenizer_2 = tokenizer_2
        self.noise_scheduler = noise_scheduler
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2
        self.vae = vae
        self.unet = unet

        self.text_encoder_1_lora = text_encoder_1_lora
        self.text_encoder_2_lora = text_encoder_2_lora
        self.unet_lora = unet_lora

    def create_pipeline(self) -> DiffusionPipeline:
        """Assemble an inference pipeline from the held SDXL components."""
        return StableDiffusionXLPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder_1,
            text_encoder_2=self.text_encoder_2,
            tokenizer=self.tokenizer_1,
            tokenizer_2=self.tokenizer_2,
            unet=self.unet,
            scheduler=self.noise_scheduler,
        )
121 changes: 121 additions & 0 deletions modules/modelLoader/StableDiffusionXLLoRAModelLoader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json
import os

import torch
from safetensors.torch import load_file
from torch import Tensor

from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util.TrainProgress import TrainProgress
from modules.util.enum.ModelType import ModelType


class StableDiffusionXLLoRAModelLoader(BaseModelLoader):
    """Loads a Stable Diffusion XL base model together with LoRA weights.

    Supported LoRA formats, tried in order by `load`: the trainer's internal
    directory layout (meta.json + lora/optimizer/ema subdirectories), a
    PyTorch checkpoint file, and a safetensors file.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def __get_rank(state_dict: dict) -> int:
        """Infer the LoRA rank from the first down-projection weight found.

        Raises:
            ValueError: if the state dict contains no `lora_down.weight`
                entry. (fix: previously fell off the end and returned None,
                which surfaced later as an opaque TypeError.)
        """
        for name, state in state_dict.items():
            if "lora_down.weight" in name:
                return state.shape[0]
        raise ValueError("state_dict contains no 'lora_down.weight' entry; cannot determine LoRA rank")

    @staticmethod
    def __init_lora(model: StableDiffusionXLModel, state_dict: dict[str, Tensor]):
        """Create LoRA wrappers for both text encoders and the UNet, then load weights."""
        rank = StableDiffusionXLLoRAModelLoader.__get_rank(state_dict)

        model.text_encoder_1_lora = LoRAModuleWrapper(
            orig_module=model.text_encoder_1,
            rank=rank,
            prefix="lora_te",
        ).to(dtype=torch.float32)
        model.text_encoder_1_lora.load_state_dict(state_dict)

        model.text_encoder_2_lora = LoRAModuleWrapper(
            orig_module=model.text_encoder_2,
            rank=rank,
            prefix="lora_te_2",
        ).to(dtype=torch.float32)
        model.text_encoder_2_lora.load_state_dict(state_dict)

        model.unet_lora = LoRAModuleWrapper(
            orig_module=model.unet,
            rank=rank,
            prefix="lora_unet",
            module_filter=["attentions"],
        ).to(dtype=torch.float32)
        model.unet_lora.load_state_dict(state_dict)

    @staticmethod
    def __load_safetensors(model: StableDiffusionXLModel, lora_name: str) -> bool:
        """Best-effort load of a safetensors LoRA file; returns success."""
        try:
            state_dict = load_file(lora_name)
            StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict)
            return True
        except Exception:  # fix: bare except also swallowed KeyboardInterrupt/SystemExit
            return False

    @staticmethod
    def __load_ckpt(model: StableDiffusionXLModel, lora_name: str) -> bool:
        """Best-effort load of a PyTorch checkpoint LoRA file; returns success."""
        try:
            state_dict = torch.load(lora_name)
            StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict)
            return True
        except Exception:  # fix: bare except also swallowed KeyboardInterrupt/SystemExit
            return False

    @staticmethod
    def __load_internal(model: StableDiffusionXLModel, lora_name: str) -> bool:
        """Best-effort load from the trainer's internal directory layout; returns success."""
        try:
            with open(os.path.join(lora_name, "meta.json"), "r") as meta_file:
                meta = json.load(meta_file)
                train_progress = TrainProgress(
                    epoch=meta['train_progress']['epoch'],
                    epoch_step=meta['train_progress']['epoch_step'],
                    epoch_sample=meta['train_progress']['epoch_sample'],
                    global_step=meta['train_progress']['global_step'],
                )

            # lora model (fix: comment previously said "embedding model")
            loaded = StableDiffusionXLLoRAModelLoader.__load_ckpt(
                model,
                os.path.join(lora_name, "lora", "lora.pt")
            )
            if not loaded:
                return False

            # optimizer state is optional — resume works without it
            try:
                model.optimizer_state_dict = torch.load(os.path.join(lora_name, "optimizer", "optimizer.pt"))
            except FileNotFoundError:
                pass

            # ema state is optional as well
            try:
                model.ema_state_dict = torch.load(os.path.join(lora_name, "ema", "ema.pt"))
            except FileNotFoundError:
                pass

            # meta
            model.train_progress = train_progress

            return True
        except Exception:  # fix: bare except also swallowed KeyboardInterrupt/SystemExit
            return False

    def load(self, model_type: ModelType, base_model_name: str, extra_model_name: str | None) -> StableDiffusionXLModel | None:
        """Load the SDXL base model, then attach LoRA weights from `extra_model_name`.

        Formats are tried in order (internal dir, ckpt, safetensors); if none
        succeeds the base model is returned without LoRA attached.
        """
        base_model_loader = StableDiffusionXLModelLoader()
        model = base_model_loader.load(model_type, base_model_name, None)

        lora_loaded = self.__load_internal(model, extra_model_name)

        if not lora_loaded:
            lora_loaded = self.__load_ckpt(model, extra_model_name)

        if not lora_loaded:
            lora_loaded = self.__load_safetensors(model, extra_model_name)

        return model
Loading

0 comments on commit 59a325c

Please sign in to comment.