Skip to content

[LoRA] support kohya and xlabs loras for flux. #9295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 293 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import re

import torch

from ..utils import is_peft_version, logging


Expand Down Expand Up @@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
return {new_name: alpha}


# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here

# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up

def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]

# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank

# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2

down_weight = down_weight * scale_down
up_weight = up_weight * scale_up

# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]

# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")

# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})

# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416

# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]

def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mod_lin",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mod_lin",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)

for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.proj_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_modulation_lin",
f"transformer.single_transformer_blocks.{i}.norm.linear",
)

if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")

return ait_sd

return _convert_sd_scripts_to_ai_toolkit(state_dict)


# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
# Some utilities were reused from
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
new_state_dict = {}
orig_keys = list(old_state_dict.keys())

def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
down_weight = sds_sd.pop(sds_key)
up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))

# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]

# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]

# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})

# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416

for old_key in orig_keys:
# Handle double_blocks
if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.transformer_blocks.{block_num}"

if "processor.proj_lora1" in old_key:
new_key += ".attn.to_out.0"
elif "processor.proj_lora2" in old_key:
new_key += ".attn.to_add_out"
elif "processor.qkv_lora1" in old_key and "up" not in old_key:
handle_qkv(
old_state_dict,
new_state_dict,
old_key,
[
f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
],
)
# continue
elif "processor.qkv_lora2" in old_key and "up" not in old_key:
handle_qkv(
old_state_dict,
new_state_dict,
old_key,
[
f"transformer.transformer_blocks.{block_num}.attn.to_q",
f"transformer.transformer_blocks.{block_num}.attn.to_k",
f"transformer.transformer_blocks.{block_num}.attn.to_v",
],
)
# continue

if "down" in old_key:
new_key += ".lora_A.weight"
elif "up" in old_key:
new_key += ".lora_B.weight"

# Handle single_blocks
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"

if "proj_lora1" in old_key or "proj_lora2" in old_key:
new_key += ".proj_out"
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
new_key += ".norm.linear"

if "down" in old_key:
new_key += ".lora_A.weight"
elif "up" in old_key:
new_key += ".lora_B.weight"

else:
# Handle other potential key patterns here
new_key = old_key

# Since we already handle qkv above.
if "qkv" not in old_key:
new_state_dict[new_key] = old_state_dict.pop(old_key)

if len(old_state_dict) > 0:
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")

return new_state_dict
21 changes: 20 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@
scale_lora_layers,
)
from .lora_base import LoraBaseMixin
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
from .lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)


if is_transformers_available():
Expand Down Expand Up @@ -1583,6 +1588,20 @@ def lora_state_dict(
allow_pickle=allow_pickle,
)

# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.

is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
return (state_dict, None) if return_alphas else state_dict

is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
return (state_dict, None) if return_alphas else state_dict

# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
Expand Down
Loading