Merge pull request #1025 from d8ahazard/dev

dev to main [release 0.12.0]

ArrowM authored Mar 6, 2023
2 parents 19d27b6 + 6a0b22b, commit 5f7d9b8

Showing 17 changed files with 309 additions and 165 deletions.
2 changes: 2 additions & 0 deletions README.md

@@ -347,6 +347,8 @@ they'll help me help you faster.
 
 [Feature Request](https://github.com/d8ahazard/sd_dreambooth_extension/issues/new?assignees=&labels=&template=feature_request.md&title=)
 
+[Discord](https://discord.gg/q8dtpfRD5w)
+
 # Credits
 
 [Huggingface.co](https://huggingface.co) - All the things
15 changes: 4 additions & 11 deletions dreambooth/dataclasses/db_config.py

@@ -5,14 +5,15 @@
 
 from pydantic import BaseModel
 
-from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
-
 try:
     from extensions.sd_dreambooth_extension.dreambooth import shared
     from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_concept import Concept
+    from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_names
+
 except:
     from dreambooth.dreambooth import shared  # noqa
     from dreambooth.dreambooth.dataclasses.db_concept import Concept  # noqa
+    from dreambooth.dreambooth.utils.image_utils import get_scheduler_names  # noqa
 
 # Keys to save, replacing our dumb __init__ method
 save_keys = []

@@ -44,7 +45,6 @@ class DreamboothConfig(BaseModel):
     gradient_checkpointing: bool = True
     gradient_set_to_none: bool = True
     graph_smoothing: int = 50
-    half_lora: bool = False
     half_model: bool = False
     train_unfrozen: bool = True
     has_ema: bool = False

@@ -95,6 +95,7 @@ class DreamboothConfig(BaseModel):
     save_lora_after: bool = True
     save_lora_cancel: bool = False
     save_lora_during: bool = True
+    save_lora_for_extra_net: bool = True
     save_preview_every: int = 5
     save_safetensors: bool = True
     save_state_after: bool = False

@@ -146,14 +147,6 @@ def __init__(self, model_name: str = "", v2: bool = False, src: str = "",
         self.scheduler = "ddim"
         self.v2 = v2
 
-        # Naive fixes for bad types
-        if not isinstance(self.lora_model_name, str):
-            print("Bad lora_model_name found, setting to ''")
-            self.lora_model_name = ''
-        if not isinstance(self.stop_text_encoder, float):
-            print("Bad stop_text_encoder found, setting to 0.0")
-            self.stop_text_encoder = 0.0
-
     # Actually save as a file
     def save(self, backup=False):
         """
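Note: the dropped isinstance guards in `__init__` leave type handling to the pydantic `BaseModel` that `DreamboothConfig` extends. The commit does not state this rationale, so take the following as background only: pydantic v1, current when this commit landed, coerces compatible values at construction and raises on incompatible ones instead of silently resetting them.

```python
# Illustrative sketch only, not extension code: pydantic v1 field coercion.
from pydantic import BaseModel, ValidationError

class MiniConfig(BaseModel):
    lora_model_name: str = ""
    stop_text_encoder: float = 0.0

print(MiniConfig(stop_text_encoder="0.75").stop_text_encoder)  # 0.75, str coerced to float

try:
    MiniConfig(stop_text_encoder=object())  # incompatible type
except ValidationError as err:
    print(err)  # raises instead of silently falling back to 0.0
```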
4 changes: 1 addition & 3 deletions dreambooth/dataset/class_dataset.py

@@ -32,14 +32,12 @@ def __init__(self, concepts: [Concept], model_dir: str, max_width: int, shuffle:
         # Data for new prompts to generate
         self.new_prompts = {}
         self.required_prompts = 0
-        # Calculate minimum width
-        min_width = (int(max_width * 0.28125) // 64) * 64
 
         # Thingy to build prompts
         text_getter = FilenameTextGetter(shuffle)
 
         # Create available resolutions
-        bucket_resos = make_bucket_resolutions(max_width, min_width)
+        bucket_resos = make_bucket_resolutions(max_width)
         c_idx = 0
         c_images = {}
         i_images = {}
4 changes: 2 additions & 2 deletions dreambooth/dataset/db_dataset.py

@@ -153,15 +153,15 @@ def cache_caption(self, image_path, caption):
         self.caption_cache[image_path] = input_ids
         return caption, input_ids
 
-    def make_buckets_with_caching(self, vae, min_size):
+    def make_buckets_with_caching(self, vae):
         self.vae = vae
         self.cache_latents = vae is not None
         state = f"Preparing Dataset ({'With Caching' if self.cache_latents else 'Without Caching'})"
         print(state)
         status.textinfo = state
 
         # Create a list of resolutions
-        bucket_resos = make_bucket_resolutions(self.resolution, min_size)
+        bucket_resos = make_bucket_resolutions(self.resolution)
         self.train_dict = {}
 
         def sort_images(img_data: List[PromptData], resos, target_dict, is_class_img):
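Note: both call sites stop deriving a minimum edge with `(int(max_width * 0.28125) // 64) * 64` and now pass only the maximum, so `make_bucket_resolutions` presumably computes its own floor; its new body is not part of this diff. For orientation, a minimal sketch of this style of aspect-ratio bucketing (64-pixel steps, pixel area capped at the square bucket):

```python
# A minimal sketch only, NOT the extension's implementation, which this diff
# does not show. Assumes the minimum edge moved inside the function.
def make_bucket_resolutions_sketch(max_width: int, divisor: int = 64):
    min_width = (int(max_width * 0.28125) // divisor) * divisor  # the old callers' formula
    max_area = max_width * max_width
    resos = set()
    for w in range(min_width, max_width * 2, divisor):
        for h in range(min_width, max_width * 2, divisor):
            if w * h <= max_area:  # keep total pixel count at or below the square bucket
                resos.add((w, h))
    return sorted(resos)

print(make_bucket_resolutions_sketch(512)[:3])  # [(128, 128), (128, 192), (128, 256)]
```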
49 changes: 21 additions & 28 deletions dreambooth/diff_to_sd.py

@@ -16,7 +16,7 @@
 
 try:
     from extensions.sd_dreambooth_extension.dreambooth import shared as shared
-    from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file
+    from extensions.sd_dreambooth_extension.dreambooth.dataclasses.db_config import from_file, DreamboothConfig
     from extensions.sd_dreambooth_extension.dreambooth.shared import status
     from extensions.sd_dreambooth_extension.dreambooth.utils.model_utils import unload_system_models, \
         reload_system_models, \

@@ -26,7 +26,7 @@
     from extensions.sd_dreambooth_extension.lora_diffusion.lora import merge_lora_to_model
 except:
     from dreambooth.dreambooth import shared as shared  # noqa
-    from dreambooth.dreambooth.dataclasses.db_config import from_file  # noqa
+    from dreambooth.dreambooth.dataclasses.db_config import from_file, DreamboothConfig  # noqa
     from dreambooth.dreambooth.shared import status  # noqa
     from dreambooth.dreambooth.utils.model_utils import unload_system_models, reload_system_models, \
         disable_safe_unpickle, enable_safe_unpickle, import_model_class_from_model_name_or_path  # noqa

@@ -338,13 +338,13 @@ def get_model_path(working_dir: str, model_name: str = "", file_extra: str = "")
     return None
 
 
-def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bool = True, log: bool = True,
+def compile_checkpoint(model_name: str, lora_file_name: str = None, reload_models: bool = True, log: bool = True,
                        snap_rev: str = ""):
     """
     @param model_name: The model name to compile
     @param reload_models: Whether to reload the system list of checkpoints.
-    @param lora_path: The path to a lora pt file to merge with the unet. Auto set during training.
+    @param lora_file_name: The path to a lora pt file to merge with the unet. Auto set during training.
     @param log: Whether to print messages to console/UI.
     @param snap_rev: The revision of snapshot to load from
     @return: status: What happened, path: Checkpoint path

@@ -355,8 +355,8 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
     status.job_count = 7
 
     config = from_file(model_name)
-    if lora_path is None and config.lora_model_name:
-        lora_path = config.lora_model_name
+    if lora_file_name is None and config.lora_model_name:
+        lora_file_name = config.lora_model_name
     save_model_name = model_name if config.custom_model_name == "" else config.custom_model_name
     if config.custom_model_name == "":
         printi(f"Compiling checkpoint for {model_name}...", log=log)

@@ -418,10 +418,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
         pass
 
     # Apply LoRA to the unet
-    if lora_path is not None and lora_path != "":
+    if lora_file_name is not None and lora_file_name != "":
         unet_model = UNet2DConditionModel().from_pretrained(os.path.dirname(unet_path))
-        lora_rev = apply_lora(unet_model, lora_path, config.lora_unet_rank, config.lora_weight, "cpu", False,
-                              config.use_lora_extended)
+        lora_rev = apply_lora(config, unet_model, lora_file_name, "cpu", False)
         unet_state_dict = copy.deepcopy(unet_model.state_dict())
         del unet_model
         if lora_rev is not None:

@@ -448,9 +447,9 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
     printi("Converting text encoder...", log=log)
 
     # Apply lora weights to the tenc
-    if lora_path is not None and lora_path != "":
-        lora_paths = lora_path.split(".")
-        lora_txt_path = f"{lora_paths[0]}_txt.{lora_paths[1]}"
+    if lora_file_name is not None and lora_file_name != "":
+        lora_paths = lora_file_name.split(".")
+        lora_txt_file_name = f"{lora_paths[0]}_txt.{lora_paths[1]}"
         text_encoder_cls = import_model_class_from_model_name_or_path(config.pretrained_model_name_or_path,
                                                                       config.revision)
 

@@ -461,8 +460,7 @@ def compile_checkpoint(model_name: str, lora_path: str = None, reload_models: bo
             torch_dtype=torch.float32
         )
 
-        apply_lora(text_encoder, lora_txt_path, config.lora_txt_rank, config.lora_txt_weight, "cpu", True,
-                   config.use_lora_extended)
+        apply_lora(config, text_encoder, lora_txt_file_name, "cpu", True)
         text_enc_dict = copy.deepcopy(text_encoder.state_dict())
         del text_encoder
     else:

@@ -551,20 +549,15 @@ def load_model(model_path: str, map_location: str):
     return loaded
 
 
-def apply_lora(model: nn.Module, loras: str, rank: int, weight: float, device: str, is_tenc: bool, use_extended: bool):
+def apply_lora(config: DreamboothConfig, model: nn.Module, lora_file_name: str, device: str, is_tenc: bool):
     lora_rev = None
-    if loras is not None and loras != "":
-        if not os.path.exists(loras):
-            try:
-                cmd_lora_models_path = shared.lora_models_path
-            except:
-                cmd_lora_models_path = None
-            model_dir = os.path.dirname(cmd_lora_models_path) if cmd_lora_models_path else shared.models_path
-            loras = os.path.join(model_dir, "lora", loras)
-
-        if os.path.exists(loras):
-            lora_rev = loras.split("_")[-1].replace(".pt", "")
-            printi(f"Loading lora from {loras}", log=True)
-            merge_lora_to_model(model, load_model(loras, device), is_tenc, use_extended, rank, weight)
+    if lora_file_name is not None and lora_file_name != "":
+        if not os.path.exists(lora_file_name):
+            lora_file_name = os.path.join(config.model_dir, "loras", lora_file_name)
+        if os.path.exists(lora_file_name):
+            lora_rev = lora_file_name.split("_")[-1].replace(".pt", "")
+            printi(f"Loading lora from {lora_file_name}", log=True)
+            merge_lora_to_model(model, load_model(lora_file_name, device), is_tenc, config.use_lora_extended,
+                                config.lora_unet_rank, config.lora_weight)
 
     return lora_rev
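Note: `apply_lora` now takes the `DreamboothConfig` instead of five loose parameters, so rank, weight, and the extended-LoRA flag cannot drift between the unet and text-encoder call sites, and bare file names resolve against the model's own `loras` directory. A usage sketch with hypothetical names:

```python
# Hypothetical usage of the new signature; "my_db_model" and the unet path are
# placeholders, not values from this commit.
config = from_file("my_db_model")
unet = UNet2DConditionModel.from_pretrained("/models/dreambooth/my_db_model/working/unet")
# A bare file name is resolved against <config.model_dir>/loras before loading.
lora_rev = apply_lora(config, unet, "my_db_model_10000.pt", device="cpu", is_tenc=False)
```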
4 changes: 3 additions & 1 deletion dreambooth/sd_to_diff.py

@@ -28,7 +28,6 @@
 from huggingface_hub import HfApi, hf_hub_download
 from omegaconf import OmegaConf
 
-from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
 
 try:
     from extensions.sd_dreambooth_extension.dreambooth import shared

@@ -37,12 +36,15 @@
         enable_safe_unpickle
     from extensions.sd_dreambooth_extension.dreambooth.utils.utils import printi
     from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
+    from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
+
 except:
     from dreambooth.dreambooth import shared  # noqa
     from dreambooth.dreambooth.dataclasses.db_config import DreamboothConfig  # noqa
     from dreambooth.dreambooth.utils.model_utils import get_db_models, disable_safe_unpickle, enable_safe_unpickle  # noqa
     from dreambooth.dreambooth.utils.utils import printi  # noqa
     from dreambooth.helpers.mytqdm import mytqdm  # noqa
+    from dreambooth.dreambooth.utils.image_utils import get_scheduler_class  # noqa
 
 from diffusers import (
     AutoencoderKL,
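Note: this is the same move as in db_config.py: the module-level import of `get_scheduler_class` relocates into the dual-path import block. The shape of that recurring pattern, reduced to a generic sketch:

```python
# Shape of the repo's dual-path import. The first path resolves when the code
# runs as a webui extension; the except path serves standalone checkouts.
# (The repo uses a bare `except:`; catching ImportError is the narrower form.)
try:
    from extensions.sd_dreambooth_extension.dreambooth.utils.image_utils import get_scheduler_class
except ImportError:
    from dreambooth.dreambooth.utils.image_utils import get_scheduler_class  # noqa
```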
6 changes: 3 additions & 3 deletions dreambooth/shared.py

@@ -16,7 +16,7 @@
 def load_auto_settings():
     global models_path, script_path, ckpt_dir, device_id, disable_safe_unpickle, dataset_filename_word_regex, \
         dataset_filename_join_string, show_progress_every_n_steps, parallel_processing_allowed, state, ckptfix, medvram, \
-        lowvram, dreambooth_models_path, lora_models_path, CLIP_stop_at_last_layers, profile_db, debug, config, device, \
+        lowvram, dreambooth_models_path, ui_lora_models_path, CLIP_stop_at_last_layers, profile_db, debug, config, device, \
         force_cpu, embeddings_dir, sd_model
     try:
         import modules.script_callbacks

@@ -51,7 +51,7 @@ def set_model(new_model):
 
     try:
         dreambooth_models_path = ws.cmd_opts.dreambooth_models_path or dreambooth_models_path
-        lora_models_path = ws.cmd_opts.lora_models_path or lora_models_path
+        ui_lora_models_path = ws.cmd_opts.lora_models_path or ui_lora_models_path
         embeddings_dir = ws.cmd_opts.embeddings_dir or embeddings_dir
     except:
         pass

@@ -293,7 +293,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
 embeddings_dir = os.path.join(script_path, "embeddings")
 dreambooth_models_path = os.path.join(models_path, "dreambooth")
 ckpt_dir = os.path.join(models_path, "Stable-diffusion")
-lora_models_path = os.path.join(models_path, "lora")
+ui_lora_models_path = os.path.join(models_path, "lora")
 db_model_config = None
 data_path = os.path.join(script_path, ".cache")
 show_progress_every_n_steps = 10
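Note: the rename separates the webui's extra-networks lora directory from the per-model `loras` directory that training now writes to (see train_dreambooth.py below). Roughly:

```python
# Sketch of the two locations the rename distinguishes; the paths follow the
# diff, but models_path and model_dir values here are placeholders.
import os

models_path = "/stable-diffusion-webui/models"                         # placeholder
model_dir = "/stable-diffusion-webui/models/dreambooth/my_db_model"    # placeholder

ui_lora_models_path = os.path.join(models_path, "lora")  # webui extra-networks dir (.safetensors)
training_loras_dir = os.path.join(model_dir, "loras")    # per-model training output (.pt)
```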
64 changes: 37 additions & 27 deletions dreambooth/train_dreambooth.py

@@ -45,6 +45,7 @@
 from extensions.sd_dreambooth_extension.dreambooth.xattention import optim_to
 from extensions.sd_dreambooth_extension.helpers.ema_model import EMAModel
 from extensions.sd_dreambooth_extension.helpers.mytqdm import mytqdm
+from extensions.sd_dreambooth_extension.lora_diffusion.extra_networks import save_extra_networks
 from extensions.sd_dreambooth_extension.lora_diffusion.lora import save_lora_weight, \
     TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module
 from extensions.sd_dreambooth_extension.dreambooth.deis_velocity import get_velocity

@@ -67,6 +68,7 @@
 from dreambooth.dreambooth.xattention import optim_to  # noqa
 from dreambooth.helpers.ema_model import EMAModel  # noqa
 from dreambooth.helpers.mytqdm import mytqdm  # noqa
+from dreambooth.lora_diffusion.extra_networks import save_extra_networks  # noqa
 from dreambooth.lora_diffusion.lora import save_lora_weight, TEXT_ENCODER_DEFAULT_TARGET_REPLACE, get_target_module  # noqa
 
 logger = logging.getLogger(__name__)

@@ -241,21 +243,12 @@ def create_vae():
     vae = create_vae()
     printm("Created vae")
 
-    try:
-        unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="unet",
-            revision=args.revision,
-            torch_dtype=torch.float32
-        )
-    except:
-        unet = UNet2DConditionModel.from_pretrained(
-            args.pretrained_model_name_or_path,
-            subfolder="unet",
-            revision=args.revision,
-            torch_dtype=torch.float16
-        )
-        unet = unet.to(dtype=torch.float32)
+    unet = UNet2DConditionModel.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="unet",
+        revision=args.revision,
+        torch_dtype=torch.float32
+    )
     unet = torch2ify(unet)
 
     # Check that all trainable models are in full precision

@@ -329,7 +322,7 @@ def create_vae():
     if args.use_lora:
         unet.requires_grad_(False)
         if args.lora_model_name:
-            lora_path = os.path.join(shared.models_path, "lora", args.lora_model_name)
+            lora_path = os.path.join(args.model_dir, "loras", args.lora_model_name)
             lora_txt = lora_path.replace(".pt", "_txt.pt")
 
             if not os.path.exists(lora_path) or not os.path.isfile(lora_path):

@@ -716,8 +709,8 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
                 requires_safety_checker=None
             )
             scheduler_class = get_scheduler_class(args.scheduler)
-            s_pipeline.unet = torch2ify(s_pipeline.unet)
             s_pipeline.enable_attention_slicing()
+            s_pipeline.unet = torch2ify(s_pipeline.unet)
             xformerify(s_pipeline)
 
             s_pipeline.scheduler = scheduler_class.from_config(s_pipeline.scheduler.config)

@@ -735,6 +728,7 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
             pbar.update()
             try:
                 out_file = None
+                # Loras resume from pt
                 if not args.use_lora:
                     if save_snapshot:
                         pbar.set_description("Saving Snapshot")

@@ -756,27 +750,43 @@ def save_weights(save_image, save_model, save_snapshot, save_checkpoint, save_lo
 
                 elif save_lora:
                     pbar.set_description("Saving Lora Weights...")
+                    # setup directory
+                    loras_dir = os.path.join(args.model_dir, "loras")
+                    os.makedirs(loras_dir, exist_ok=True)
+                    # setup pt path
                     lora_model_name = args.model_name if args.custom_model_name == "" else args.custom_model_name
-                    model_dir = os.path.dirname(shared.lora_models_path)
-                    out_file = os.path.join(model_dir, "lora")
-                    os.makedirs(out_file, exist_ok=True)
-                    out_file = os.path.join(out_file, f"{lora_model_name}_{args.revision}.pt")
+                    lora_file_prefix = f"{lora_model_name}_{args.revision}"
+                    out_file = os.path.join(loras_dir, f"{lora_file_prefix}.pt")
+                    # create pt
                     tgt_module = get_target_module("module", args.use_lora_extended)
-                    d_type = torch.float16 if args.half_lora else torch.float32
+                    save_lora_weight(s_pipeline.unet, out_file, tgt_module)
 
-                    save_lora_weight(s_pipeline.unet, out_file, tgt_module, d_type=d_type)
+                    modelmap = {"unet": (s_pipeline.unet, tgt_module)}
+                    # save text_encoder
                     if stop_text_percentage != 0:
                         out_txt = out_file.replace(".pt", "_txt.pt")
+                        modelmap["text_encoder"] = (s_pipeline.text_encoder, TEXT_ENCODER_DEFAULT_TARGET_REPLACE)
                         save_lora_weight(s_pipeline.text_encoder,
                                          out_txt,
-                                         target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
-                                         d_type=d_type)
+                                         target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE)
                         pbar.update()
 
+                    # save extra_net
+                    if args.save_lora_for_extra_net:
+                        if args.use_lora_extended:
+                            import sys
+                            has_locon = len([path for path in sys.path if 'a1111-sd-webui-locon' in path]) > 0
+                            if not has_locon:
+                                raise Exception(r"a1111-sd-webui-locon extension is required to save "
+                                                r"extra net for extended lora. Please install "
+                                                r"https://github.com/KohakuBlueleaf/a1111-sd-webui-locon")
+                        os.makedirs(shared.ui_lora_models_path, exist_ok=True)
+                        out_safe = os.path.join(shared.ui_lora_models_path, f"{lora_file_prefix}.safetensors")
+                        save_extra_networks(modelmap, out_safe)
+                # package pt into checkpoint
                 if save_checkpoint:
                     pbar.set_description("Compiling Checkpoint")
                     snap_rev = str(args.revision) if save_snapshot else ""
-                    compile_checkpoint(args.model_name, reload_models=False, lora_path=out_file, log=False,
+                    compile_checkpoint(args.model_name, reload_models=False, lora_file_name=out_file, log=False,
                                        snap_rev=snap_rev)
                     pbar.update()
                 printm("Restored, moved to acc.device.")
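Note: taken together, the new LoRA save flow writes the `.pt` (and `_txt.pt`) weights into the model's own `loras` directory and, when `save_lora_for_extra_net` is set, exports a `.safetensors` copy to the webui lora directory via `save_extra_networks` so it shows up as an extra network. A condensed sketch (names from the diff; `args`, `shared`, `unet`, `text_encoder`, and `stop_text_percentage` are assumed from the surrounding training scope, and the `_txt.pt` save is omitted):

```python
# Condensed sketch of the save flow after this commit, not a verbatim excerpt.
import os

loras_dir = os.path.join(args.model_dir, "loras")
os.makedirs(loras_dir, exist_ok=True)
prefix = f"{args.model_name}_{args.revision}"

tgt_module = get_target_module("module", args.use_lora_extended)
save_lora_weight(unet, os.path.join(loras_dir, f"{prefix}.pt"), tgt_module)

modelmap = {"unet": (unet, tgt_module)}
if stop_text_percentage != 0:
    modelmap["text_encoder"] = (text_encoder, TEXT_ENCODER_DEFAULT_TARGET_REPLACE)

if args.save_lora_for_extra_net:
    # extended lora additionally requires a1111-sd-webui-locon on sys.path
    os.makedirs(shared.ui_lora_models_path, exist_ok=True)
    save_extra_networks(modelmap, os.path.join(shared.ui_lora_models_path, f"{prefix}.safetensors"))
```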