
[LoRA Attn Processors] Refactor LoRA Attn Processors #4765


Merged: 26 commits merged into main from refactor_lora_attn on Aug 28, 2023

Commits (26)
308d510  [LoRA Attn] Refactor LoRA attn (patrickvonplaten, Aug 24, 2023)
06d5859  correct for network alphas (patrickvonplaten, Aug 24, 2023)
ed47c6c  fix more (patrickvonplaten, Aug 24, 2023)
ec29980  fix more tests (patrickvonplaten, Aug 24, 2023)
1116254  fix more tests (patrickvonplaten, Aug 24, 2023)
ff35325  Move below (patrickvonplaten, Aug 24, 2023)
b5339fa  Finish (patrickvonplaten, Aug 24, 2023)
e40f212  better version (patrickvonplaten, Aug 25, 2023)
4c80f2a  Merge branch 'main' into refactor_lora_attn (patrickvonplaten, Aug 25, 2023)
6edf0fa  correct serialization format (patrickvonplaten, Aug 25, 2023)
3724fbf  Merge branch 'main' into refactor_lora_attn (patrickvonplaten, Aug 25, 2023)
d5b6514  fix (patrickvonplaten, Aug 25, 2023)
2f669f4  Merge branch 'refactor_lora_attn' of https://github.com/huggingface/d… (patrickvonplaten, Aug 25, 2023)
45dce9f  fix more (patrickvonplaten, Aug 25, 2023)
2f207b8  fix more (patrickvonplaten, Aug 25, 2023)
00aca18  fix more (patrickvonplaten, Aug 25, 2023)
54fb5eb  Apply suggestions from code review (patrickvonplaten, Aug 25, 2023)
efe35c1  Merge branch 'main' into refactor_lora_attn (patrickvonplaten, Aug 25, 2023)
34150a7  Update src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_… (patrickvonplaten, Aug 25, 2023)
a69e2bd  deprecation (patrickvonplaten, Aug 25, 2023)
7c5a3de  relax atol for slow test slightly (patrickvonplaten, Aug 25, 2023)
402de4b  Finish tests (patrickvonplaten, Aug 25, 2023)
c904423  make style (patrickvonplaten, Aug 25, 2023)
dfcada5  Merge branch 'main' into refactor_lora_attn (patrickvonplaten, Aug 26, 2023)
7154084  Merge branch 'main' into refactor_lora_attn (patrickvonplaten, Aug 26, 2023)
d9373a4  make style (patrickvonplaten, Aug 27, 2023)
202 changes: 66 additions & 136 deletions src/diffusers/loaders.py
@@ -24,7 +24,6 @@
import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, model_info
from torch import nn

@@ -231,15 +230,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

"""
from .models.attention_processor import (
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
CustomDiffusionAttnProcessor,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
SlicedAttnAddedKVProcessor,
XFormersAttnProcessor,
)
from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer

@@ -314,24 +305,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
state_dict = pretrained_model_name_or_path_or_dict

# fill attn processors
attn_processors = {}
non_attn_lora_layers = []
lora_layers_list = []

is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

if is_lora:
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
warnings.warn(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
# correct keys
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)

lora_grouped_dict = defaultdict(dict)
mapped_network_alphas = {}
@@ -367,87 +348,38 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
if "lora.down.weight" in value_dict:
rank = value_dict["lora.down.weight"].shape[0]

if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size

lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
non_attn_lora_layers.append((attn_processor, lora))
rank = value_dict["lora.down.weight"].shape[0]

if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size

lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
# To handle SDXL.
rank_mapping = {}
hidden_size_mapping = {}
for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]

rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})

if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
else:
attn_processor_class = (
LoRAAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAAttnProcessor
)

if attn_processor_class is not LoRAAttnAddedKVProcessor:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight"),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
q_rank=rank_mapping.get("to_q_lora.down.weight"),
q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
v_rank=rank_mapping.get("to_v_lora.down.weight"),
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
out_rank=rank_mapping.get("to_out_lora.down.weight"),
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
)
else:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight", None),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
)
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")

attn_processors[key].load_state_dict(value_dict)
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
lora_layers_list.append((attn_processor, lora))

elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
if len(value) == 0:
@@ -475,22 +407,47 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)

self.set_attn_processor(attn_processors)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)

# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
non_attn_lora_layers = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in non_attn_lora_layers]

# set layers
self.set_attn_processor(attn_processors)
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]

# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
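
For context, the refactor above replaces the per-processor LoRA attention classes with plain LoRA layers that are attached directly to the UNet's LoRACompatibleLinear / LoRACompatibleConv modules via set_lora_layer. A minimal sketch of that pattern, with illustrative feature sizes and rank rather than values taken from this PR:

import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

# A LoRA-compatible linear module, as used inside the UNet attention blocks.
linear = LoRACompatibleLinear(320, 320)

# Build a low-rank adapter with matching in/out features, mirroring the
# LoRALinearLayer construction in load_attn_procs above (rank=4 is illustrative).
lora = LoRALinearLayer(320, 320, rank=4, network_alpha=None)

# Attach it; the module's forward now adds the scaled low-rank update on top
# of the regular projection.
linear.set_lora_layer(lora)

sample = torch.randn(1, 320)
out = linear(sample)  # base projection + LoRA update

Conv modules follow the same pattern with LoRAConv2dLayer, which additionally takes kernel_size, stride, and padding, as shown in the loader code above.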

def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warn(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}

# change processor format to 'pure' LoRACompatibleLinear format
if any("processor" in k.split(".") for k in state_dict.keys()):

def format_to_lora_compatible(key):
if "processor" not in key.split("."):
return key
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")

state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}

if network_alphas is not None:
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
return state_dict, network_alphas
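
As an illustration of format_to_lora_compatible, a legacy processor-style key such as

    mid_block.attentions.0.transformer_blocks.0.attn1.processor.to_out_lora.up.weight

is rewritten to

    mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.lora.up.weight

so that it addresses the to_out.0 LoRACompatibleLinear sub-module directly rather than an attention processor (the block path here is illustrative; to_k/to_q/to_v keys are covered by the generic "_lora" -> ".lora" replacement).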

def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
@@ -1748,36 +1705,9 @@ def unload_lora_weights(self):
>>> ...
```
"""
from .models.attention_processor import (
LORA_ATTENTION_PROCESSORS,
AttnProcessor,
AttnProcessor2_0,
LoRAAttnAddedKVProcessor,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)

unet_attention_classes = {type(processor) for _, processor in self.unet.attn_processors.items()}

if unet_attention_classes.issubset(LORA_ATTENTION_PROCESSORS):
# Handle attention processors that are a mix of regular attention and AddedKV
# attention.
if len(unet_attention_classes) > 1 or LoRAAttnAddedKVProcessor in unet_attention_classes:
self.unet.set_default_attn_processor()
else:
regular_attention_classes = {
LoRAAttnProcessor: AttnProcessor,
LoRAAttnProcessor2_0: AttnProcessor2_0,
LoRAXFormersAttnProcessor: XFormersAttnProcessor,
}
[attention_proc_class] = unet_attention_classes
self.unet.set_attn_processor(regular_attention_classes[attention_proc_class]())

for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)

# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
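
From the pipeline side the entry points are unchanged; unloading now simply resets the per-module LoRA layers instead of swapping attention processor classes. A minimal usage sketch (the checkpoint id and LoRA path are placeholders):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Attach LoRA layers to the UNet (and text encoder, if the weights include it).
pipe.load_lora_weights("path/to/lora")

# Detach them again; internally this calls set_lora_layer(None) on every UNet
# module that exposes it, as in the diff above.
pipe.unload_lora_weights()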