forked from Nerogar/OneTrainer
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
work in progress on SDXL LoRA training
Showing 24 changed files with 1,757 additions and 17 deletions.
@@ -2,6 +2,7 @@
/venv*
/debug*
/workspace*
/models*
/training_concepts
/training_samples
debug.py
14 changes: 14 additions & 0 deletions
modules/dataLoader/MgdsStableDiffusionXLFineTuneDataLoader.py
@@ -0,0 +1,14 @@
from modules.dataLoader.MgdsStableDiffusionXLBaseDataLoader import MgdsStablDiffusionXLBaseDataLoader
from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.util.TrainProgress import TrainProgress
from modules.util.args.TrainArgs import TrainArgs


class MgdsStableDiffusionXLFineTuneDataLoader(MgdsStablDiffusionXLBaseDataLoader):
    def __init__(
            self,
            args: TrainArgs,
            model: StableDiffusionXLModel,
            train_progress: TrainProgress,
    ):
        super(MgdsStableDiffusionXLFineTuneDataLoader, self).__init__(args, model, train_progress)
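For context, a minimal sketch of how this loader would typically be constructed; `train_args` and `model` are placeholder objects assumed to come from the argument-parsing and model-loading steps, and are not part of this commit.

    # Illustrative sketch only - the surrounding setup is assumed, not shown in this diff.
    data_loader = MgdsStableDiffusionXLFineTuneDataLoader(
        args=train_args,                 # TrainArgs parsed from CLI/config (assumed)
        model=model,                     # an already-loaded StableDiffusionXLModel (assumed)
        train_progress=TrainProgress(),  # fresh progress for a new run (default constructor assumed)
    )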
376 changes: 376 additions & 0 deletions
modules/dataLoader/MgdsStableDiffusionXLBaseDataLoader.py
Large diffs are not rendered by default.
67 changes: 67 additions & 0 deletions
modules/model/StableDiffusionXLModel.py
@@ -0,0 +1,67 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, EulerDiscreteScheduler, \
    StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTokenizer

from modules.model.BaseModel import BaseModel
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util.TrainProgress import TrainProgress
from modules.util.enum.ModelType import ModelType


class StableDiffusionXLModel(BaseModel):
    # base model data
    model_type: ModelType
    tokenizer_1: CLIPTokenizer
    tokenizer_2: CLIPTokenizer
    noise_scheduler: EulerDiscreteScheduler
    text_encoder_1: CLIPTextModel
    text_encoder_2: CLIPTextModel
    vae: AutoencoderKL
    unet: UNet2DConditionModel

    # persistent training data
    text_encoder_1_lora: LoRAModuleWrapper | None
    text_encoder_2_lora: LoRAModuleWrapper | None
    unet_lora: LoRAModuleWrapper | None

    def __init__(
            self,
            model_type: ModelType,
            tokenizer_1: CLIPTokenizer,
            tokenizer_2: CLIPTokenizer,
            noise_scheduler: EulerDiscreteScheduler,
            text_encoder_1: CLIPTextModel,
            text_encoder_2: CLIPTextModel,
            vae: AutoencoderKL,
            unet: UNet2DConditionModel,
            optimizer_state_dict: dict | None = None,
            ema_state_dict: dict | None = None,
            train_progress: TrainProgress = None,
            text_encoder_1_lora: LoRAModuleWrapper | None = None,
            text_encoder_2_lora: LoRAModuleWrapper | None = None,
            unet_lora: LoRAModuleWrapper | None = None,
    ):
        super(StableDiffusionXLModel, self).__init__(model_type, optimizer_state_dict, ema_state_dict, train_progress)

        self.tokenizer_1 = tokenizer_1
        self.tokenizer_2 = tokenizer_2
        self.noise_scheduler = noise_scheduler
        self.text_encoder_1 = text_encoder_1
        self.text_encoder_2 = text_encoder_2
        self.vae = vae
        self.unet = unet

        self.text_encoder_1_lora = text_encoder_1_lora
        self.text_encoder_2_lora = text_encoder_2_lora
        self.unet_lora = unet_lora

    def create_pipeline(self) -> DiffusionPipeline:
        return StableDiffusionXLPipeline(
            vae=self.vae,
            text_encoder=self.text_encoder_1,
            text_encoder_2=self.text_encoder_2,
            tokenizer=self.tokenizer_1,
            tokenizer_2=self.tokenizer_2,
            unet=self.unet,
            scheduler=self.noise_scheduler,
        )
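For reference, a hedged sketch of how create_pipeline() could be exercised for a quick sampling check; the device, prompt, and loaded `model` instance are assumptions rather than part of this commit.

    # Illustrative only: assemble the diffusers pipeline from the wrapped components
    # and run one sampling pass. Assumes `model` is a fully loaded StableDiffusionXLModel.
    pipeline = model.create_pipeline()
    pipeline.to("cuda")  # assumes a CUDA device is available
    image = pipeline(prompt="a photo of a cat", num_inference_steps=30).images[0]
    image.save("sample.png")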
121 changes: 121 additions & 0 deletions
modules/modelLoader/StableDiffusionXLLoRAModelLoader.py
@@ -0,0 +1,121 @@
import json
import os

import torch
from safetensors.torch import load_file
from torch import Tensor

from modules.model.StableDiffusionXLModel import StableDiffusionXLModel
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.StableDiffusionXLModelLoader import StableDiffusionXLModelLoader
from modules.module.LoRAModule import LoRAModuleWrapper
from modules.util.TrainProgress import TrainProgress
from modules.util.enum.ModelType import ModelType


class StableDiffusionXLLoRAModelLoader(BaseModelLoader):
    def __init__(self):
        super(StableDiffusionXLLoRAModelLoader, self).__init__()

    @staticmethod
    def __get_rank(state_dict: dict) -> int:
        for name, state in state_dict.items():
            if "lora_down.weight" in name:
                return state.shape[0]

    @staticmethod
    def __init_lora(model: StableDiffusionXLModel, state_dict: dict[str, Tensor]):
        rank = StableDiffusionXLLoRAModelLoader.__get_rank(state_dict)

        model.text_encoder_1_lora = LoRAModuleWrapper(
            orig_module=model.text_encoder_1,
            rank=rank,
            prefix="lora_te",
        ).to(dtype=torch.float32)
        model.text_encoder_1_lora.load_state_dict(state_dict)

        model.text_encoder_2_lora = LoRAModuleWrapper(
            orig_module=model.text_encoder_2,
            rank=rank,
            prefix="lora_te_2",
        ).to(dtype=torch.float32)
        model.text_encoder_2_lora.load_state_dict(state_dict)

        model.unet_lora = LoRAModuleWrapper(
            orig_module=model.unet,
            rank=rank,
            prefix="lora_unet",
            module_filter=["attentions"],
        ).to(dtype=torch.float32)
        model.unet_lora.load_state_dict(state_dict)

    @staticmethod
    def __load_safetensors(model: StableDiffusionXLModel, lora_name: str) -> bool:
        try:
            state_dict = load_file(lora_name)
            StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict)
            return True
        except:
            return False

    @staticmethod
    def __load_ckpt(model: StableDiffusionXLModel, lora_name: str) -> bool:
        try:
            state_dict = torch.load(lora_name)
            StableDiffusionXLLoRAModelLoader.__init_lora(model, state_dict)
            return True
        except:
            return False

    @staticmethod
    def __load_internal(model: StableDiffusionXLModel, lora_name: str) -> bool:
        try:
            with open(os.path.join(lora_name, "meta.json"), "r") as meta_file:
                meta = json.load(meta_file)
                train_progress = TrainProgress(
                    epoch=meta['train_progress']['epoch'],
                    epoch_step=meta['train_progress']['epoch_step'],
                    epoch_sample=meta['train_progress']['epoch_sample'],
                    global_step=meta['train_progress']['global_step'],
                )

            # lora
            loaded = StableDiffusionXLLoRAModelLoader.__load_ckpt(
                model,
                os.path.join(lora_name, "lora", "lora.pt")
            )
            if not loaded:
                return False

            # optimizer
            try:
                model.optimizer_state_dict = torch.load(os.path.join(lora_name, "optimizer", "optimizer.pt"))
            except FileNotFoundError:
                pass

            # ema
            try:
                model.ema_state_dict = torch.load(os.path.join(lora_name, "ema", "ema.pt"))
            except FileNotFoundError:
                pass

            # meta
            model.train_progress = train_progress

            return True
        except:
            return False

    def load(self, model_type: ModelType, base_model_name: str, extra_model_name: str | None) -> StableDiffusionXLModel | None:
        base_model_loader = StableDiffusionXLModelLoader()
        model = base_model_loader.load(model_type, base_model_name, None)

        lora_loaded = self.__load_internal(model, extra_model_name)

        if not lora_loaded:
            lora_loaded = self.__load_ckpt(model, extra_model_name)

        if not lora_loaded:
            lora_loaded = self.__load_safetensors(model, extra_model_name)

        return model
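To show the intended call pattern, a hypothetical usage sketch; the ModelType member name, base model path, and LoRA file name are placeholders, not taken from this commit.

    # Hypothetical usage (placeholder names): load an SDXL base model, then attach LoRA weights.
    loader = StableDiffusionXLLoRAModelLoader()
    model = loader.load(
        model_type=ModelType.STABLE_DIFFUSION_XL_10_BASE,  # assumed enum member name
        base_model_name="stabilityai/stable-diffusion-xl-base-1.0",
        extra_model_name="my_sdxl_lora.safetensors",
    )
    # On success, model.unet_lora and model.text_encoder_*_lora hold the loaded LoRA wrappers.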