Update FLUX.1 support for compact models
kohya-ss committed Oct 12, 2024
1 parent ecaea90 commit e277b57
Showing 4 changed files with 82 additions and 18 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -11,6 +11,16 @@ The command to install PyTorch is as follows:

### Recent Updates

+Oct 12, 2024 (update 1):
+
+- [Experimental] FLUX.1 fine-tuning and LoRA training now support "FLUX.1 __compact__" models.
+  - A compact model retains the FLUX.1 architecture but reduces the number of double/single blocks from the default 19/38.
+  - The model type is detected automatically from the keys in the *.safetensors file.
+  - Specifications for compact model safetensors:
+    - The block indices must be consecutive; an error occurs if any index is missing. For example, if you reduce the double blocks to 15, the highest key will be `double_blocks.14.*`. The same applies to single blocks.
+  - LoRA training is unverified.
+- The trained model can be used for inference with `flux_minimal_inference.py`. Other inference environments are unverified.
+
Oct 12, 2024:

- Multi-GPU training now works on Windows. Thanks to Akegarasu for PR [#1686](https://github.com/kohya-ss/sd-scripts/pull/1686)!
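As a reference for the detection rule described in the README entry above, here is a minimal sketch (an editor's illustration, not part of the commit) that derives the block counts from a BFL-format `*.safetensors` file and enforces the consecutive-index requirement. The helper name and error handling are illustrative; the actual detection lives in `library/flux_utils.py` below.

```python
# Illustrative only: count double/single blocks in a BFL-format checkpoint
# and verify that the block indices are consecutive, as the README requires.
from safetensors import safe_open


def compact_block_counts(ckpt_path: str) -> tuple[int, int]:
    with safe_open(ckpt_path, framework="pt") as f:
        keys = list(f.keys())

    def block_indices(prefix: str) -> set[int]:
        return {int(k.split(".")[1]) for k in keys if k.startswith(prefix)}

    double_idx = block_indices("double_blocks.")
    single_idx = block_indices("single_blocks.")
    if not double_idx or not single_idx:
        raise ValueError("not a BFL-format FLUX.1 checkpoint")

    # e.g. 15 double blocks must appear as double_blocks.0 .. double_blocks.14
    for name, idx in (("double", double_idx), ("single", single_idx)):
        if idx != set(range(max(idx) + 1)):
            raise ValueError(f"{name}_blocks indices are not consecutive: {sorted(idx)}")

    return max(double_idx) + 1, max(single_idx) + 1
```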
12 changes: 6 additions & 6 deletions flux_train.py
@@ -137,7 +137,7 @@ def train(args):

train_dataset_group.verify_bucket_reso_steps(16)  # TODO: check that this is correct

-    _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
+    _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -181,7 +181,7 @@ def train(args):
# load VAE for caching latents
ae = None
if cache_latents:
-        ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
+        ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
@@ -510,8 +510,8 @@ def wait_blocks_move(block_idx, futures):
library.adafactor_fused.patch_adafactor_fused(optimizer)

blocks_to_swap = args.blocks_to_swap
-        num_double_blocks = 19  # len(flux.double_blocks)
-        num_single_blocks = 38  # len(flux.single_blocks)
+        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
+        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()

@@ -603,8 +603,8 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
parameter_optimizer_map = {}

blocks_to_swap = args.blocks_to_swap
-        num_double_blocks = 19  # len(flux.double_blocks)
-        num_single_blocks = 38  # len(flux.single_blocks)
+        num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
+        num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2

n = 1 # only asynchronous purpose, no need to increase this number
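The hunks above replace the hardcoded 19/38 block counts with counts read from the instantiated model, so block swapping also works for compact models. The unit arithmetic itself is unchanged; a quick check of the formula (the 15/30 compact layout is a hypothetical example, not from the commit):

```python
# The formula counts each double block as one unit and each pair of
# single blocks as one unit, hence the integer division by two.
def num_block_units(num_double_blocks: int, num_single_blocks: int) -> int:
    return num_double_blocks + num_single_blocks // 2

print(num_block_units(19, 38))  # 38 units for the default FLUX.1 layout
print(num_block_units(15, 30))  # 30 units for a hypothetical compact layout
```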
2 changes: 1 addition & 1 deletion flux_train_network.py
@@ -139,7 +139,7 @@ def prepare_split_model(self, model, weight_dtype, accelerator):
return flux_lower

def get_tokenize_strategy(self, args):
-        _, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
+        _, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)

if args.t5xxl_max_token_length is None:
if is_schnell:
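The hunk is truncated here; for context, a sketch of the branch it feeds into. The 256/512 defaults are assumptions based on the T5-XXL sequence lengths of FLUX.1 schnell and dev, not something this commit changes:

```python
# Sketch of the surrounding logic (assumed, not part of this diff):
# schnell checkpoints default to 256 T5-XXL tokens, dev to 512.
if args.t5xxl_max_token_length is None:
    t5xxl_max_token_length = 256 if is_schnell else 512
else:
    t5xxl_max_token_length = args.t5xxl_max_token_length
```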
76 changes: 65 additions & 11 deletions library/flux_utils.py
@@ -1,3 +1,4 @@
+from dataclasses import replace
import json
import os
from typing import List, Optional, Tuple, Union
@@ -43,8 +44,21 @@ def load_safetensors(
return load_file(path) # prevent device invalid Error


-def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool, List[str]]:
-    # check the state dict: Diffusers or BFL, dev or schnell
+def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
+    """
+    Analyze the checkpoint state: Diffusers or BFL, dev or schnell, and the number of blocks.
+
+    Args:
+        ckpt_path (str): Path to the checkpoint file or directory.
+
+    Returns:
+        Tuple[bool, bool, Tuple[int, int], List[str]]:
+            - bool: whether the checkpoint is in Diffusers format.
+            - bool: whether the model is schnell.
+            - Tuple[int, int]: the numbers of double and single blocks.
+            - List[str]: the list of keys contained in the checkpoint.
+    """
+    # check the state dict: Diffusers or BFL, dev or schnell, number of blocks
logger.info(f"Checking the state dict: Diffusers or BFL, dev or schnell")

if os.path.isdir(ckpt_path): # if ckpt_path is a directory, it is Diffusers
@@ -61,19 +75,57 @@ def check_flux_state_dict_diffusers_schnell(ckpt_path: str) -> Tuple[bool, bool,

is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
-    return is_diffusers, is_schnell, ckpt_paths

+    # check number of double and single blocks
+    if not is_diffusers:
+        max_double_block_index = max(
+            [int(key.split(".")[1]) for key in keys if key.startswith("double_blocks.") and key.endswith(".img_attn.proj.bias")]
+        )
+        max_single_block_index = max(
+            [int(key.split(".")[1]) for key in keys if key.startswith("single_blocks.") and key.endswith(".modulation.lin.bias")]
+        )
+    else:
+        max_double_block_index = max(
+            [
+                int(key.split(".")[1])
+                for key in keys
+                if key.startswith("transformer_blocks.") and key.endswith(".attn.add_k_proj.bias")
+            ]
+        )
+        max_single_block_index = max(
+            [
+                int(key.split(".")[1])
+                for key in keys
+                if key.startswith("single_transformer_blocks.") and key.endswith(".attn.to_k.bias")
+            ]
+        )
+
+    num_double_blocks = max_double_block_index + 1
+    num_single_blocks = max_single_block_index + 1
+
+    return is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths


def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
) -> Tuple[bool, flux_models.Flux]:
-    is_diffusers, is_schnell, ckpt_paths = check_flux_state_dict_diffusers_schnell(ckpt_path)
+    is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
-        model = flux_models.Flux(flux_models.configs[name].params)
+        params = flux_models.configs[name].params
+
+        # set the number of blocks
+        if params.depth != num_double_blocks:
+            logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
+            params = replace(params, depth=num_double_blocks)
+        if params.depth_single_blocks != num_single_blocks:
+            logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
+            params = replace(params, depth_single_blocks=num_single_blocks)
+
+        model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)

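The new `from dataclasses import replace` import supports the block-count override above: `replace` returns a copy of a (possibly frozen) dataclass with selected fields changed, so the shared config in `flux_models.configs` is never mutated. A self-contained illustration with a stand-in params class (`FluxParams` here is hypothetical, not the repository's class):

```python
# FluxParams is a stand-in for the real flux_models config dataclass.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FluxParams:
    depth: int = 19                # number of double blocks
    depth_single_blocks: int = 38  # number of single blocks


base = FluxParams()
compact = replace(base, depth=15, depth_single_blocks=30)
assert base.depth == 19     # the shared config is untouched
assert compact.depth == 15  # the copy carries the compact counts
```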
@@ -86,7 +138,7 @@ def load_flow_model(
# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
-        sd = convert_diffusers_sd_to_bfl(sd)
+        sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")

info = model.load_state_dict(sd, strict=False, assign=True)
@@ -349,16 +401,16 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
}


-def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
+def make_diffusers_to_bfl_map(num_double_blocks: int, num_single_blocks: int) -> dict[str, tuple[int, str]]:
# make reverse map from diffusers map
diffusers_to_bfl_map = {} # key: diffusers_key, value: (index, bfl_key)
-    for b in range(NUM_DOUBLE_BLOCKS):
+    for b in range(num_double_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("double_blocks."):
block_prefix = f"transformer_blocks.{b}."
for i, weight in enumerate(weights):
diffusers_to_bfl_map[f"{block_prefix}{weight}"] = (i, key.replace("()", f"{b}"))
-    for b in range(NUM_SINGLE_BLOCKS):
+    for b in range(num_single_blocks):
for key, weights in BFL_TO_DIFFUSERS_MAP.items():
if key.startswith("single_blocks."):
block_prefix = f"single_transformer_blocks.{b}."
@@ -371,8 +423,10 @@ def make_diffusers_to_bfl_map() -> dict[str, tuple[int, str]]:
return diffusers_to_bfl_map


-def convert_diffusers_sd_to_bfl(diffusers_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-    diffusers_to_bfl_map = make_diffusers_to_bfl_map()
+def convert_diffusers_sd_to_bfl(
+    diffusers_sd: dict[str, torch.Tensor], num_double_blocks: int = NUM_DOUBLE_BLOCKS, num_single_blocks: int = NUM_SINGLE_BLOCKS
+) -> dict[str, torch.Tensor]:
+    diffusers_to_bfl_map = make_diffusers_to_bfl_map(num_double_blocks, num_single_blocks)

# iterate over three safetensors files to reduce memory usage
flux_sd = {}
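Taken together, the conversion path for a compact Diffusers checkpoint looks roughly like this (an editor's sketch based on the hunks above; `ckpt_path` and `sd`, the loaded Diffusers state dict, are assumed to exist):

```python
from library import flux_utils

# Detect format, variant, and per-checkpoint block counts in one pass.
is_diffusers, is_schnell, (n_double, n_single), ckpt_paths = flux_utils.analyze_checkpoint_state(ckpt_path)
if is_diffusers:
    # Omitting the counts falls back to the 19/38 defaults (the old behavior).
    sd = flux_utils.convert_diffusers_sd_to_bfl(sd, n_double, n_single)
```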
