
Commit 83e3048

load Diffusers format, check schnell/dev

kohya-ss committed Oct 6, 2024
1 parent ba08a89 commit 83e3048
Showing 6 changed files with 196 additions and 111 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -11,6 +11,10 @@ The command to install PyTorch is as follows:

### Recent Updates

Oct 6, 2024:
- In FLUX.1 LoRA training and fine-tuning, the specified weight file (*.safetensors) is now automatically detected as dev or schnell, so schnell models are loaded correctly. Note that LoRA training and fine-tuning with schnell models are unverified.
- FLUX.1 LoRA training and fine-tuning can now load weights in Diffusers format in addition to BFL format (a single *.safetensors file). Specify the full path to the parent directory of `transformer` or to `diffusion_pytorch_model-00001-of-00003.safetensors`. Diffusers-format CLIP/T5XXL is not supported, and saving is supported only in BFL format. See the loading sketch below.
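
For illustration, here is a minimal loading sketch in Python (the paths are hypothetical, and the `from library import flux_utils` import assumes this repository's package layout). `load_flow_model` now detects BFL vs. Diffusers and dev vs. schnell by itself and returns `is_schnell` along with the model:

```python
import torch
from library import flux_utils

# BFL format: a single *.safetensors file
is_schnell, model = flux_utils.load_flow_model("/models/flux1-schnell.safetensors", torch.bfloat16, "cpu")

# Diffusers format: pass the parent directory of `transformer`
is_schnell, model = flux_utils.load_flow_model("/models/FLUX.1-dev", torch.bfloat16, "cpu")

# schnell uses a shorter T5XXL token length than dev
t5xxl_max_token_length = 256 if is_schnell else 512
```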

Sep 26, 2024:
The implementation of block swap during FLUX.1 fine-tuning has been changed, improving speed by about 10% (depending on the environment). A new `--blocks_to_swap` option has been added, and `--double_blocks_to_swap` and `--single_blocks_to_swap` are deprecated; they still work for now but will be removed in the future. See [FLUX.1 fine-tuning](#flux1-fine-tuning) for details.

15 changes: 6 additions & 9 deletions flux_minimal_inference.py
@@ -419,9 +419,6 @@ def encode(prpt: str):
steps = args.steps
guidance_scale = args.guidance

name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
is_schnell = name == "schnell"

def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]

@@ -455,12 +452,8 @@ def is_fp8(dt):
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)

t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

# DiT
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
@@ -469,8 +462,12 @@ def is_fp8(dt):
# if args.offload:
# model = model.to("cpu")

t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()

# AE
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
ae.eval()
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)
15 changes: 7 additions & 8 deletions flux_train.py
@@ -137,16 +137,16 @@ def train(args):

train_dataset_group.verify_bucket_reso_steps(16)  # TODO: confirm this is appropriate

_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512)
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))

@@ -177,12 +177,11 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)

# load the model
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"

# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
Expand All @@ -196,7 +195,7 @@ def train(args):

# prepare tokenize strategy
if args.t5xxl_max_token_length is None:
if name == "schnell":
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
@@ -258,8 +257,8 @@ def train(args):
clean_memory_on_device(accelerator.device)

# load FLUX
flux = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)

if args.gradient_checkpointing:
Expand Down Expand Up @@ -294,7 +293,7 @@ def train(args):

if not cache_latents:
# load VAE here if not cached
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)
17 changes: 7 additions & 10 deletions flux_train_network.py
@@ -2,7 +2,7 @@
import copy
import math
import random
from typing import Any
from typing import Any, Optional

import torch
from accelerate import Accelerator
@@ -24,6 +24,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None

def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -57,19 +58,15 @@ def assert_extra_args(self, args, train_dataset_group):

train_dataset_group.verify_bucket_reso_steps(32) # TODO check this

def get_flux_model_name(self, args):
return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"

def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
name = self.get_flux_model_name(args)

# if the file is fp8 and we are using fp8_base, we can load it as is (fp8)
loading_dtype = None if args.fp8_base else weight_dtype

# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
model = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
self.is_schnell, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
)
if args.fp8_base:
# check dtype of model
@@ -100,7 +97,7 @@ def load_target_model(self, args, weight_dtype, accelerator):
elif t5xxl.dtype == torch.float8_e4m3fn:
logger.info("Loaded fp8 T5XXL model")

ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

@@ -142,10 +139,10 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
return flux_lower

def get_tokenize_strategy(self, args):
name = self.get_flux_model_name(args)
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)

if args.t5xxl_max_token_length is None:
if name == "schnell":
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
178 changes: 170 additions & 8 deletions library/flux_utils.py
@@ -1,9 +1,11 @@
import json
from typing import Optional, Union
import os
from typing import List, Optional, Tuple, Union
import einops
import torch

from safetensors.torch import load_file
from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config

@@ -17,6 +19,8 @@
logger = logging.getLogger(__name__)

MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"


# temporary copy from sd3_utils TODO refactor
@@ -39,29 +43,64 @@ def load_safetensors(
return load_file(path) # prevent device invalid Error


def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
# check the state dict: Diffusers or BFL, dev or schnell
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")

if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
ckpt_path = os.path.join(ckpt_path, "transformer", "diffusion_pytorch_model-00001-of-00003.safetensors")
if "00001-of-00003" in ckpt_path:
ckpt_paths = [ckpt_path.replace("00001-of-00003", f"0000{i}-of-00003") for i in range(1, 4)]
else:
ckpt_paths = [ckpt_path]

keys = []
for ckpt_path in ckpt_paths:
with safe_open(ckpt_path, framework="pt") as f:
keys.extend(f.keys())

is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
return is_diffusers, is_schnell, ckpt_paths
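
# Usage sketch (paths are hypothetical): detection is purely key-based, so the same call
# handles both a single BFL checkpoint and a sharded Diffusers directory:
#   check_flux_state_dict_diffusers_schnell("/models/flux1-dev.safetensors")
#       -> (False, False, ["/models/flux1-dev.safetensors"])
#   check_flux_state_dict_diffusers_schnell("/models/FLUX.1-schnell")
#       -> (True, True, ["/models/FLUX.1-schnell/transformer/diffusion_pytorch_model-00001-of-00003.safetensors", ...])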


def load_flow_model(
name: str, ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.Flux:
logger.info(f"Building Flux model {name}")
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
model = flux_models.Flux(flux_models.configs[name].params)
if dtype is not None:
model = model.to(dtype)

# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))

# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd)
logger.info("Converted Diffusers to BFL")

info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return model
return is_schnell, model


def load_ae(
name: str, ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
) -> flux_models.AutoEncoder:
logger.info("Building AutoEncoder")
with torch.device("meta"):
ae = flux_models.AutoEncoder(flux_models.configs[name].ae_params).to(dtype)
# dev and schnell have the same AE params
ae = flux_models.AutoEncoder(flux_models.configs[MODEL_NAME_DEV].ae_params).to(dtype)

logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
@@ -246,3 +285,126 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
"""
x = einops.rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
return x


# region Diffusers

NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38

BFL_TO_DIFFUSERS_MAP = {
"time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
"time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
"time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
"time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
"vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
"vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
"vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
"vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
"guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
"guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
"guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
"guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
"txt_in.weight": ["context_embedder.weight"],
"txt_in.bias": ["context_embedder.bias"],
"img_in.weight": ["x_embedder.weight"],
"img_in.bias": ["x_embedder.bias"],
"double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
"double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
"double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
"double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
"double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
"double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
"double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
"double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
"double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
"double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
"double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
"double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
"double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
"double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
"double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
"double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
"double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
"double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
"double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
"double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
"double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
"double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
"double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
"double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
"single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
"single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
"single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
"single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
"single_blocks.().linear2.weight": ["proj_out.weight"],
"single_blocks.().linear2.bias": ["proj_out.bias"],
"final_layer.linear.weight": ["proj_out.weight"],
"final_layer.linear.bias": ["proj_out.bias"],
"final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
"final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
}


def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
for b in range(NUM_DOUBLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for b in range(NUM_SINGLE_BLOCKS):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if not (key.startswith("double_blocks.") or key.startswith("single_blocks.")):
for i, weight in enumerate(weights):
diffusers_to_bfl_map[weight] = (i, key)
return diffusers_to_bfl_map
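
# For example, the expanded reverse map (derived from BFL_TO_DIFFUSERS_MAP above) contains:
#   "transformer_blocks.0.attn.to_q.weight"       -> (0, "double_blocks.0.img_attn.qkv.weight")
#   "transformer_blocks.0.attn.to_k.weight"       -> (1, "double_blocks.0.img_attn.qkv.weight")
#   "single_transformer_blocks.3.proj_mlp.weight" -> (3, "single_blocks.3.linear1.weight")
# The index records each tensor's position within the fused BFL qkv/linear1 weight.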


def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
diffusers_to_bfl_map = make_diffusers_to_bfl_map()

# iterate over three safetensors files to reduce memory usage
flux_sd = {}
for diffusers_key, tensor in diffusers_sd.items():
if diffusers_key in diffusers_to_bfl_map:
index, bfl_key = diffusers_to_bfl_map[diffusers_key]
if bfl_key not in flux_sd:
flux_sd[bfl_key] = []
flux_sd[bfl_key].append((index, tensor))
else:
logger.error(f"Error: Key not found in diffusers_to_bfl_map: {diffusers_key}")
raise KeyError(f"Key not found in diffusers_to_bfl_map: {diffusers_key}")

# concat tensors if multiple tensors are mapped to a single key, sort by index
for key, values in flux_sd.items():
if len(values) == 1:
flux_sd[key] = values[0][1]
else:
flux_sd[key] = torch.cat([value[1] for value in sorted(values, key=lambda x: x[0])])

# special case for final_layer.adaLN_modulation.1.weight and final_layer.adaLN_modulation.1.bias
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
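# Minimal sanity check of the reordering: the two halves along dim 0 are exchanged,
# e.g. swap_scale_shift(torch.arange(4.0)) -> tensor([2., 3., 0., 1.])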

if "final_layer.adaLN_modulation.1.weight" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.weight"])
if "final_layer.adaLN_modulation.1.bias" in flux_sd:
flux_sd["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(flux_sd["final_layer.adaLN_modulation.1.bias"])

return flux_sd
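
# End-to-end sketch (hypothetical path), mirroring what load_flow_model does internally:
#   _, _, ckpt_paths = check_flux_state_dict_diffusers_schnell("/models/FLUX.1-dev")
#   sd = {}
#   for p in ckpt_paths:
#       sd.update(load_safetensors(p, device="cpu"))
#   bfl_sd = convert_diffusers_sd_to_bfl(sd)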


# endregion
