Loading LoRA weights fails for OneTrainer Flux LoRAs #10972

Closed
@spezialspezial

Description


Describe the bug

Loading OneTrainer-style LoRAs with diffusers commit dcd77ce22273708294b7b9c2f7f0a4e45d7a9f33 fails with the following error:

Traceback (most recent call last):
  File "/+DEV/diffusers-edge/src/diffusers/loaders/lora_pipeline.py", line 1527, in load_lora_weights
    state_dict, network_alphas = self.lora_state_dict(
  File "/+DEVTOOL/miniconda3/envs/flux/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/+DEV/diffusers-edge/src/diffusers/loaders/lora_pipeline.py", line 1450, in lora_state_dict
    state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
  File "/+DEV/diffusers-edge/src/diffusers/loaders/lora_conversion_utils.py", line 687, in _convert_kohya_flux_lora_to_diffusers
    return _convert_mixture_state_dict_to_diffusers(state_dict)
  File "/+DEV/diffusers-edge/src/diffusers/loaders/lora_conversion_utils.py", line 659, in _convert_mixture_state_dict_to_diffusers
    if remaining_all_unet:
UnboundLocalError: local variable 'remaining_all_unet' referenced before assignment
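For what it's worth, the error pattern itself is a plain Python one: the variable is only assigned on a branch that OneTrainer checkpoints apparently never take. A hypothetical reduction of the converter's control flow (not the actual diffusers code, just an illustration):

# Hypothetical reduction -- key prefixes and structure are assumptions,
# meant only to show how the UnboundLocalError can arise.
def _convert_mixture(state_dict):
    if any(k.startswith(("lora_unet_", "lora_te")) for k in state_dict):
        # 'remaining_all_unet' is only bound on this branch ...
        remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
    # ... but read unconditionally, so a checkpoint whose keys all start
    # with "lora_transformer_" (OneTrainer) never binds it and crashes here.
    if remaining_all_unet:
        ...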

Basic OneTrainer LoRA structure:

"onetrainer": {
	"transformer_name": "lora_transformer_",
	"double_block_name": "transformer_blocks_",
	"single_block_name": "single_transformer_blocks_",
	"double_module_names": (
		"_attn_to_out_0",
		("_attn_to_q", "_attn_to_k", "_attn_to_v"),
		"_ff_net_0_proj",
		"_ff_net_2",
		"_norm1_linear",
		"_attn_to_add_out",
		("_attn_add_q_proj", "_attn_add_k_proj", "_attn_add_v_proj"),
		"_ff_context_net_0_proj", "_ff_context_net_2",
		"_norm1_context_linear"
	),
	"single_module_names": (
		("_attn_to_q", "_attn_to_k", "_attn_to_v", "_proj_mlp"),
		"_proj_out",
		"_norm_linear",
	),
	"param_names": (".lora_down.weight", ".lora_up.weight", ".alpha"),
	"dora_param_name": ".dora_scale",
	"text_encoder_names": ("lora_te1_", "lora_te2_"),
	"unique_meta": ("ot_branch", "ot_revision", "ot_config"),
	"comment": "kohya-diffusers mix-ish, supports modelspec, yay"
},
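Concretely, those fields compose into flat keys. A hedged sketch of one such key and the diffusers/PEFT name it presumably needs to become (the target name is my assumption for illustration, not taken from the converter):

# Illustrative key assembled from the naming scheme above:
ot_key = "lora_transformer_transformer_blocks_0_attn_to_q.lora_down.weight"

# Assumed diffusers/PEFT counterpart (naming is a guess):
diffusers_key = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"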

Example LoRAs:

https://civitai.com/models/767016?modelVersionId=857899
https://civitai.com/models/794095?modelVersionId=887953
https://civitai.com/models/754969?modelVersionId=884632
https://civitai.com/models/991928?modelVersionId=1111315
https://civitai.com/models/825919?modelVersionId=923640
https://civitai.com/models/1226276?modelVersionId=1381683

Somewhat related: #10954

Reproduction

from pathlib import Path
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig, FluxPipeline
from transformers import T5EncoderModel

repo_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
quantization_config = TorchAoConfig("int8_weight_only")

transformer = FluxTransformer2DModel.from_pretrained(
    repo_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)

text_encoder_2 = T5EncoderModel.from_pretrained(
    repo_id,
    subfolder="text_encoder_2",
    quantization_config=quantization_config,
    torch_dtype=dtype,
)

pipe = FluxPipeline.from_pretrained(
    repo_id,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
)

lora_path = Path("/-LoRAs/Flux/charcoal3000.safetensors")
pipe.load_lora_weights(lora_path, adapter_name=lora_path.stem)
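Until the converter recognizes the `lora_transformer_` prefix, one untested workaround sketch is to load the state dict yourself and remap the prefix to the kohya-style `lora_unet_` one before handing the dict to `load_lora_weights` (which accepts a state dict as well as a path). Whether the rest of the kohya conversion path then applies cleanly is an assumption:

from safetensors.torch import load_file

state_dict = load_file(lora_path)
# OneTrainer writes "lora_transformer_*"; kohya-style Flux LoRAs use
# "lora_unet_*". Remapping the prefix is an untested guess at steering
# the checkpoint onto the existing conversion path.
state_dict = {k.replace("lora_transformer_", "lora_unet_", 1): v for k, v in state_dict.items()}
pipe.load_lora_weights(state_dict, adapter_name=lora_path.stem)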

System Info

diffusers dcd77ce, Linux (like everyone), Python 3.10

Who can help?

Calling LoRA ambassador Mr. @sayakpaul
