Skip to content

Commit

Permalink
refactor SD3 CLIP to transformers etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 24, 2024
1 parent 138dac4 commit 623017f
Show file tree
Hide file tree
Showing 13 changed files with 1,130 additions and 2,079 deletions.
4 changes: 2 additions & 2 deletions flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

import library.train_util as train_util

Expand Down Expand Up @@ -241,7 +241,7 @@ def train(args):

text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
Expand Down
2 changes: 1 addition & 1 deletion flux_train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def cache_text_encoder_outputs_if_needed(
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()

prompts = sd3_train_utils.load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:
Expand Down
3 changes: 1 addition & 2 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from safetensors.torch import save_file

from library import flux_models, flux_utils, strategy_base, train_util
from library.sd3_train_utils import load_prompts
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()
Expand Down Expand Up @@ -70,7 +69,7 @@ def sample_images(
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)

save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
Expand Down
59 changes: 28 additions & 31 deletions library/flux_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,21 @@
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config

from library import flux_models

from library.utils import setup_logging, MemoryEfficientSafeOpen
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

from library import flux_models
from library.utils import load_safetensors

MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"


# temporary copy from sd3_utils TODO refactor
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
return load_file(path, device=device)
except:
return load_file(path) # prevent device invalid Error


def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
Expand Down Expand Up @@ -161,8 +142,14 @@ def load_ae(
return ae


def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
logger.info("Building CLIP")
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
Expand Down Expand Up @@ -255,15 +242,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
with init_empty_weights():
clip = CLIPTextModel._from_config(config)

logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP: {info}")
logger.info(f"Loaded CLIP-L: {info}")
return clip


def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
Expand Down Expand Up @@ -303,8 +297,11 @@ def load_t5xxl(
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)

logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded T5xxl: {info}")
return t5xxl
Expand Down
9 changes: 3 additions & 6 deletions library/sai_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"

Expand Down Expand Up @@ -140,10 +140,7 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
if sd3 == "m":
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
Expand Down
Loading

0 comments on commit 623017f

Please sign in to comment.