
Fuse loras #4473

Merged — 60 commits, Aug 29, 2023
Commits (60)
ad0ca34
Fuse loras
patrickvonplaten Aug 4, 2023
697a6a7
initial implementation.
sayakpaul Aug 24, 2023
957f36e
merge into main
sayakpaul Aug 24, 2023
703e9aa
add slow test one.
sayakpaul Aug 24, 2023
f9a7737
styling
sayakpaul Aug 24, 2023
15b7652
add: test for checking efficiency
sayakpaul Aug 24, 2023
b4a9a44
print
sayakpaul Aug 24, 2023
a6d6402
position
sayakpaul Aug 24, 2023
a167a74
place model offload correctly
sayakpaul Aug 24, 2023
14aa423
style
sayakpaul Aug 24, 2023
16311f7
style.
sayakpaul Aug 24, 2023
a355544
unfuse test.
sayakpaul Aug 24, 2023
d8050b5
final checks
sayakpaul Aug 24, 2023
96ae69e
Merge branch 'main' into fuse_loras
patrickvonplaten Aug 25, 2023
fddc586
resolve conflicts.
sayakpaul Aug 28, 2023
a976466
remove warning test
sayakpaul Aug 28, 2023
b718922
remove warnings altogether
sayakpaul Aug 28, 2023
886993e
Merge branch 'main' into fuse_loras
sayakpaul Aug 28, 2023
782367d
debugging
sayakpaul Aug 28, 2023
84f63e8
tighten up tests.
sayakpaul Aug 28, 2023
4a2e6c4
debugging
sayakpaul Aug 28, 2023
1b07e43
debugging
sayakpaul Aug 28, 2023
0d69dde
debugging
sayakpaul Aug 28, 2023
caa79ed
debugging
sayakpaul Aug 28, 2023
345057d
debugging
sayakpaul Aug 28, 2023
f26a62a
debugging
sayakpaul Aug 28, 2023
bc2282c
debugging
sayakpaul Aug 28, 2023
ced2a90
debugging
sayakpaul Aug 28, 2023
ff78a58
debugging
sayakpaul Aug 28, 2023
036a9bc
denugging
sayakpaul Aug 28, 2023
1c6970f
debugging
sayakpaul Aug 28, 2023
9f7492f
debugging
sayakpaul Aug 28, 2023
975adf7
debugging
sayakpaul Aug 28, 2023
0abf6fe
debugging
sayakpaul Aug 28, 2023
de916c6
debugging
sayakpaul Aug 28, 2023
e87e5dd
debugging
sayakpaul Aug 28, 2023
e376b58
debugging
sayakpaul Aug 28, 2023
21f17b0
debugging
sayakpaul Aug 28, 2023
5537305
debugging
sayakpaul Aug 28, 2023
73c07ee
debugging
sayakpaul Aug 28, 2023
dc30b9d
debugging
sayakpaul Aug 28, 2023
96c70e8
debugging
sayakpaul Aug 28, 2023
c5f37b5
debugging
sayakpaul Aug 28, 2023
854231b
debugging
sayakpaul Aug 28, 2023
b00899e
debuging
sayakpaul Aug 28, 2023
788b610
debugging
sayakpaul Aug 28, 2023
aedcb70
debugging
sayakpaul Aug 28, 2023
67b8aa6
debugging
sayakpaul Aug 28, 2023
1c69333
suit up the generator initialization a bit.
sayakpaul Aug 28, 2023
55f5958
remove print
sayakpaul Aug 28, 2023
73bdcb1
update assertion.
sayakpaul Aug 28, 2023
940ed1b
debugging
sayakpaul Aug 28, 2023
8fcd42a
remove print.
sayakpaul Aug 28, 2023
6e99561
fix: assertions.
sayakpaul Aug 28, 2023
c3adb8c
style
sayakpaul Aug 28, 2023
2d6cd03
can generator be a problem?
sayakpaul Aug 28, 2023
53f2e74
generator
sayakpaul Aug 28, 2023
b9ea1fc
correct tests.
sayakpaul Aug 29, 2023
9cb8ec3
support text encoder lora fusion.
sayakpaul Aug 29, 2023
50c611d
tighten up tests.
sayakpaul Aug 29, 2023

133 changes: 133 additions & 0 deletions src/diffusers/loaders.py
@@ -85,7 +85,49 @@ def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=

        self.lora_scale = lora_scale

    def _fuse_lora(self):
        if self.lora_linear_layer is None:
            return

        dtype, device = self.regular_linear_layer.weight.data.dtype, self.regular_linear_layer.weight.data.device
        logger.info(f"Fusing LoRA weights for {self.__class__}")

        w_orig = self.regular_linear_layer.weight.data.float()
        w_up = self.lora_linear_layer.up.weight.data.float()
        w_down = self.lora_linear_layer.down.weight.data.float()

        if self.lora_linear_layer.network_alpha is not None:
            w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank

        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
        self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_linear_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()

    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return
        logger.info(f"Unfusing LoRA weights for {self.__class__}")

        fused_weight = self.regular_linear_layer.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        self.w_up = self.w_up.to(device=device, dtype=dtype)
        self.w_down = self.w_down.to(device, dtype=dtype)
        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
        self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
Comment on lines +120 to +123
Contributor Author:

Actually, the reason unfuse gives different results might be that we don't do the computation in full fp32 precision here. In the _fuse_lora function we do the computation in full fp32 precision, but here I don't think we do. Can we try to make sure that:

unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]

is always computed in full fp32 precision, and only then potentially lowered to the fp16 dtype?

I guess we could also easily check this by casting the whole model to fp32 before doing fuse and unfuse. cc @apolinario
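A minimal sketch of the idea against the PatchedLoraProjection._unfuse_lora shown above (not a final implementation, just to make the suggestion concrete):

def _unfuse_lora(self):
    if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
        return

    fused_weight = self.regular_linear_layer.weight.data
    dtype, device = fused_weight.dtype, fused_weight.device

    # keep the stored LoRA factors in fp32 and do the subtraction in fp32 ...
    w_up = self.w_up.to(device=device).float()
    w_down = self.w_down.to(device=device).float()
    unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]

    # ... and only cast the result back to the original (possibly fp16) dtype
    self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)

    self.w_up = None
    self.w_down = None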

Collaborator @apolinario (Aug 29, 2023):

This is a good hypothesis, but would it explain why this happens with some LoRAs but not others, and why some residual style is kept?

  1. davizca87/vulcan → different image after unloading
  2. ostris/crayon_style_lora_sdxl → same image after unloading
  3. davizca87/sun-flower → different image after unloading (and also different from the unloaded image in 1)
  4. TheLastBen/Papercut_SDXL → same image after unloading

Collaborator @apolinario (Aug 29, 2023):

Failed cases ❌

davizca87/sun-flower seems to keep some of the sunflower vibe in the background

(images: original generation | with LoRA fused | after unfusing — residual effects from the LoRA?)

davizca87/vulcan seems to keep some of the vulcan style in the outlines of the robot

(images: original generation | with LoRA fused | after unfusing — residual effects from the LoRA?)

Success cases ✅

ostris/crayon_style_lora_sdxl seems to produce a perceptually identical image after unfusing

(images: original generation | with LoRA fused | after unfusing — same as original generation)

TheLastBen/Papercut_SDXL and nerijs/pixel-art-xl also exhibit the same correct behavior

Member:

I personally always worry about how numerical precision propagates through a network and affects the end results. I have seen enough cases of this to not sleep well at night. So, I will start with what @patrickvonplaten suggested.

FWIW, though, there's actually a fast test that ensures the following round trip doesn't have side effects:

load_lora_weights() -> fuse_lora() -> unload_lora_weights() gives you the outputs you would expect after doing fuse_lora().
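
In rough form, that round trip boils down to something like this (illustrative sketch only — the checkpoint names and settings here are placeholders, not the actual test fixtures):

import numpy as np
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0").to("cuda")
pipe.load_lora_weights("nerijs/pixel-art-xl")
pipe.fuse_lora()

fused = pipe(
    "pixel art corgi", generator=torch.manual_seed(0), num_inference_steps=2, output_type="np"
).images

# after fusing, unloading the LoRA weights should not change the outputs
pipe.unload_lora_weights()
unloaded = pipe(
    "pixel art corgi", generator=torch.manual_seed(0), num_inference_steps=2, output_type="np"
).images

assert np.allclose(fused, unloaded, atol=1e-3)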

Let me know if anything is unclear.

Collaborator @apolinario (Aug 29, 2023):

Very nice! I guess an analogous test could be added in a future PR to address what I reported.

Namely, asserting that the two no-LoRA generations match:
generate with no LoRA (before) → load_lora_weights() → fuse_lora() → unfuse_lora() → generate with no LoRA (after unfusing)

This is the workflow I reported above; the unfused UNet seemingly still contains some residue of the LoRA.
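
For instance, something like the following sketch (hypothetical prompt and an example LoRA checkpoint, only to illustrate the assertion):

image_before = pipe("a robot", generator=torch.manual_seed(0), output_type="np").images

pipe.load_lora_weights("davizca87/vulcan")
pipe.fuse_lora()
pipe.unfuse_lora()

# after unfusing, the pipeline should behave exactly like the LoRA-free one again
image_after = pipe("a robot", generator=torch.manual_seed(0), output_type="np").images

assert np.allclose(image_before, image_after, atol=1e-3)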

Contributor Author:

Should be fixed in #4833. It was a tricky issue caused by the patched text encoder LoRA layers being fully removed when doing unload_lora, and therefore losing their ability to unfuse. Hope it was ok to help out a bit here, @sayakpaul.

Member:

I had started fuse-lora-pt2 and mentioned on Slack that I am looking into it. But okay.


        self.w_up = None
        self.w_down = None

    def forward(self, input):
        if self.lora_linear_layer is None:
            return self.regular_linear_layer(input)
        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
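
For intuition: with lora_scale = 1, the fused weight reproduces this unfused forward pass, since W_fused @ x = W @ x + (network_alpha / rank) * W_up @ W_down @ x. A tiny standalone check of that identity (arbitrary shapes and plain tensors, not the classes in this diff):

import torch

torch.manual_seed(0)
out_features, in_features, rank, alpha = 16, 8, 4, 4.0

w = torch.randn(out_features, in_features)   # base linear weight
w_up = torch.randn(out_features, rank)       # LoRA up projection
w_down = torch.randn(rank, in_features)      # LoRA down projection
x = torch.randn(2, in_features)

# unfused: base output plus the scaled LoRA branch (lora_scale = 1)
unfused = x @ w.t() + (alpha / rank) * (x @ w_down.t() @ w_up.t())

# fused: fold the scaled LoRA product into the base weight, as _fuse_lora does
fused = x @ (w + (alpha / rank) * (w_up @ w_down)).t()

assert torch.allclose(unfused, fused, atol=1e-5)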


@@ -525,6 +567,20 @@ def save_function(weights, filename):
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

    def fuse_lora(self):
        self.apply(self._fuse_lora_apply)

    def _fuse_lora_apply(self, module):
        if hasattr(module, "_fuse_lora"):
            module._fuse_lora()

    def unfuse_lora(self):
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        if hasattr(module, "_unfuse_lora"):
            module._unfuse_lora()


class TextualInversionLoaderMixin:
    r"""
@@ -1712,6 +1768,83 @@ def unload_lora_weights(self):
        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()

    def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            fuse_unet (`bool`, defaults to `True`): Whether to fuse the UNet LoRA parameters.
            fuse_text_encoder (`bool`, defaults to `True`):
                Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        if fuse_unet:
            self.unet.fuse_lora()

        def fuse_text_encoder_lora(text_encoder):
            for _, attn_module in text_encoder_attn_modules(text_encoder):
                if isinstance(attn_module.q_proj, PatchedLoraProjection):
                    attn_module.q_proj._fuse_lora()
                    attn_module.k_proj._fuse_lora()
                    attn_module.v_proj._fuse_lora()
                    attn_module.out_proj._fuse_lora()

            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                if isinstance(mlp_module.fc1, PatchedLoraProjection):
                    mlp_module.fc1._fuse_lora()
                    mlp_module.fc2._fuse_lora()

        if fuse_text_encoder:
            if hasattr(self, "text_encoder"):
                fuse_text_encoder_lora(self.text_encoder)
            if hasattr(self, "text_encoder_2"):
                fuse_text_encoder_lora(self.text_encoder_2)

    def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).

        <Tip warning={true}>

        This is an experimental API.

        </Tip>

        Args:
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        if unfuse_unet:
            self.unet.unfuse_lora()

        def unfuse_text_encoder_lora(text_encoder):
            for _, attn_module in text_encoder_attn_modules(text_encoder):
                if isinstance(attn_module.q_proj, PatchedLoraProjection):
                    attn_module.q_proj._unfuse_lora()
                    attn_module.k_proj._unfuse_lora()
                    attn_module.v_proj._unfuse_lora()
                    attn_module.out_proj._unfuse_lora()

            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
                if isinstance(mlp_module.fc1, PatchedLoraProjection):
                    mlp_module.fc1._unfuse_lora()
                    mlp_module.fc2._unfuse_lora()

        if unfuse_text_encoder:
            if hasattr(self, "text_encoder"):
                unfuse_text_encoder_lora(self.text_encoder)
            if hasattr(self, "text_encoder_2"):
                unfuse_text_encoder_lora(self.text_encoder_2)
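
For illustration, end-to-end usage of these two pipeline-level methods could look roughly like the following (the model and LoRA repo IDs are only examples, not part of this diff):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("nerijs/pixel-art-xl")

# bake the LoRA parameters into the UNet / text encoder weights so that
# inference runs without the extra LoRA branches
pipe.fuse_lora()
image = pipe("pixel art, a cute corgi", num_inference_steps=30).images[0]

# restore the original, LoRA-free weights
pipe.unfuse_lora()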


class FromSingleFileMixin:
    """
93 changes: 92 additions & 1 deletion src/diffusers/models/lora.py
@@ -14,9 +14,15 @@

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
@@ -91,6 +97,51 @@ def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs
    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device
        logger.info(f"Fusing LoRA weights for {self.__class__}")

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
        fusion = fusion.reshape((w_orig.shape))
        fused_weight = w_orig + fusion
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()

    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return
        logger.info(f"Unfusing LoRA weights for {self.__class__}")

        fused_weight = self.weight.data
        dtype, device = fused_weight.data.dtype, fused_weight.data.device

        self.w_up = self.w_up.to(device=device, dtype=dtype)
        self.w_down = self.w_down.to(device, dtype=dtype)

        fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
        fusion = fusion.reshape((fused_weight.shape))
        unfused_weight = fused_weight - fusion
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, x):
        if self.lora_layer is None:
            # make sure to use the functional Conv2D function as otherwise torch.compile's graph will break
@@ -109,9 +160,49 @@ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs
        super().__init__(*args, **kwargs)
        self.lora_layer = lora_layer

-    def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
+    def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
        self.lora_layer = lora_layer

    def _fuse_lora(self):
        if self.lora_layer is None:
            return

        dtype, device = self.weight.data.dtype, self.weight.data.device
        logger.info(f"Fusing LoRA weights for {self.__class__}")

        w_orig = self.weight.data.float()
        w_up = self.lora_layer.up.weight.data.float()
        w_down = self.lora_layer.down.weight.data.float()

        if self.lora_layer.network_alpha is not None:
            w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank

        fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
        self.weight.data = fused_weight.to(device=device, dtype=dtype)

        # we can drop the lora layer now
        self.lora_layer = None

        # offload the up and down matrices to CPU to not blow the memory
        self.w_up = w_up.cpu()
        self.w_down = w_down.cpu()

    def _unfuse_lora(self):
        if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
            return
        logger.info(f"Unfusing LoRA weights for {self.__class__}")

        fused_weight = self.weight.data
        dtype, device = fused_weight.dtype, fused_weight.device

        self.w_up = self.w_up.to(device=device, dtype=dtype)
        self.w_down = self.w_down.to(device, dtype=dtype)
        unfused_weight = fused_weight - torch.bmm(self.w_up[None, :], self.w_down[None, :])[0]
        self.weight.data = unfused_weight.to(device=device, dtype=dtype)

        self.w_up = None
        self.w_down = None

    def forward(self, hidden_states, lora_scale: int = 1):
        if self.lora_layer is None:
            return super().forward(hidden_states)