
Commit 1f63b60

Authored by hipsterusername, mutatedducks97, and psychedelicious
Implementing support for Non-Standard LoRA Format (#7985)
* integrate loRA
* idk anymore tbh
* enable fused matrix for quantized models
* integrate loRA
* idk anymore tbh
* enable fused matrix for quantized models
* ruff fix

---------

Co-authored-by: Sam <bhaskarmdutt@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
1 parent a499b9f commit 1f63b60

File tree: 3 files changed, +49 -0 lines changed


invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,12 @@
 
 def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor:
     """An optimized implementation of the residual calculation for a sidecar linear LoRALayer."""
+    # up matrix and down matrix have different ranks so we can't simply multiply them
+    if lora_layer.up.shape[1] != lora_layer.down.shape[0]:
+        x = torch.nn.functional.linear(input, lora_layer.get_weight(lora_weight), bias=lora_layer.bias)
+        x *= lora_weight * lora_layer.scale()
+        return x
+
     x = torch.nn.functional.linear(input, lora_layer.down)
     if lora_layer.mid is not None:
         x = torch.nn.functional.linear(x, lora_layer.mid)
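
Note: the branch added above falls back to materializing the full LoRA delta and applying it in one step, instead of the usual two-step low-rank multiply. A minimal sketch of why the two orderings produce the same residual (made-up shapes, plain tensors rather than the LoRALayer API, ignoring bias and the alpha/weight scaling):

    import torch

    batch, in_features, out_features, rank = 2, 64, 64, 8
    x = torch.randn(batch, in_features)
    down = torch.randn(rank, in_features)   # (rank, in_features)
    up = torch.randn(out_features, rank)    # (out_features, rank)

    # Fast path: two small matmuls, valid when up.shape[1] == down.shape[0].
    residual_low_rank = torch.nn.functional.linear(torch.nn.functional.linear(x, down), up)

    # Fallback path: materialize the full delta-W first, then apply it once.
    # The mismatched-rank branch does this via lora_layer.get_weight(...), where
    # the delta-W comes from fuse_weights() instead of a plain up @ down.
    delta_w = up @ down                     # (out_features, in_features)
    residual_full = torch.nn.functional.linear(x, delta_w)

    assert torch.allclose(residual_low_rank, residual_full, atol=1e-3)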

invokeai/backend/patches/layers/lora_layer.py

Lines changed: 31 additions & 0 deletions
@@ -19,6 +19,7 @@ def __init__(
         self.up = up
         self.mid = mid
         self.down = down
+        self.are_ranks_equal = up.shape[1] == down.shape[0]
 
     @classmethod
     def from_state_dict_values(
@@ -58,12 +59,42 @@ def from_state_dict_values(
     def _rank(self) -> int:
         return self.down.shape[0]
 
+    def fuse_weights(self, up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
+        """
+        Fuse the weights of the up and down matrices of a LoRA layer with different ranks.
+
+        Since the Huggingface implementation of KQV projections are fused, when we convert to Kohya format
+        the LoRA weights have different ranks. This function handles the fusion of these differently sized
+        matrices.
+        """
+
+        fused_lora = torch.zeros((up.shape[0], down.shape[1]), device=down.device, dtype=down.dtype)
+        rank_diff = down.shape[0] / up.shape[1]
+
+        if rank_diff > 1:
+            rank_diff = down.shape[0] / up.shape[1]
+            w_down = down.chunk(int(rank_diff), dim=0)
+            for w_down_chunk in w_down:
+                fused_lora = fused_lora + (torch.mm(up, w_down_chunk))
+        else:
+            rank_diff = up.shape[1] / down.shape[0]
+            w_up = up.chunk(int(rank_diff), dim=0)
+            for w_up_chunk in w_up:
+                fused_lora = fused_lora + (torch.mm(w_up_chunk, down))
+
+        return fused_lora
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
             weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
         else:
+            # up matrix and down matrix have different ranks so we can't simply multiply them
+            if not self.are_ranks_equal:
+                weight = self.fuse_weights(self.up, self.down)
+                return weight
+
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
 
         return weight
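
Note: the rank mismatch that fuse_weights() handles arises when a fused Huggingface QKV projection is converted to Kohya format: the down matrix then stacks several rank blocks, so its rank dimension is larger than the one the up matrix expects. A standalone sketch of the chunk-and-sum step for that case (made-up small shapes, plain tensors rather than the class method):

    import torch

    # Hypothetical shapes: the down matrix stacks three rank-4 blocks (rank 12),
    # while the up matrix expects rank 4, so up @ down is not directly possible.
    up = torch.randn(96, 4)     # (out_features, rank_up)
    down = torch.randn(12, 64)  # (rank_down = 3 * rank_up, in_features)

    rank_diff = down.shape[0] // up.shape[1]  # 3 chunks
    fused = torch.zeros(up.shape[0], down.shape[1], dtype=down.dtype)
    for down_chunk in down.chunk(rank_diff, dim=0):
        fused += up @ down_chunk              # each chunk is (rank_up, in_features)

    print(fused.shape)  # torch.Size([96, 64]): a full (out_features, in_features) weight

This mirrors the rank_diff > 1 branch above; get_weight() then returns the fused matrix directly instead of multiplying up and down.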

invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py

Lines changed: 12 additions & 0 deletions
@@ -20,6 +20,14 @@
 FLUX_KOHYA_TRANSFORMER_KEY_REGEX = (
     r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"
 )
+
+# A regex pattern that matches all of the last layer keys in the Kohya FLUX LoRA format.
+# Example keys:
+#   lora_unet_final_layer_linear.alpha
+#   lora_unet_final_layer_linear.lora_down.weight
+#   lora_unet_final_layer_linear.lora_up.weight
+FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"
+
 # A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format.
 # Example keys:
 #   lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha
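
Note: a quick check (using re directly, outside the module) that the example keys listed in the comment are accepted by the new pattern:

    import re

    FLUX_KOHYA_LAST_LAYER_KEY_REGEX = r"lora_unet_final_layer_(linear|linear1|linear2)_?(.*)"

    example_keys = [
        "lora_unet_final_layer_linear.alpha",
        "lora_unet_final_layer_linear.lora_down.weight",
        "lora_unet_final_layer_linear.lora_up.weight",
    ]

    # All example keys match, so a state dict containing final-layer entries is
    # no longer rejected by is_state_dict_likely_in_flux_kohya_format().
    print(all(re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k) for k in example_keys))  # True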
@@ -44,6 +52,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
     """
     return all(
         re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k)
+        or re.match(FLUX_KOHYA_LAST_LAYER_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k)
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
@@ -65,6 +74,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
     t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
     for layer_name, layer_state_dict in grouped_state_dict.items():
         if layer_name.startswith("lora_unet"):
+            # Skip the final layer. This is incompatible with current model definition.
+            if layer_name.startswith("lora_unet_final_layer"):
+                continue
             transformer_grouped_sd[layer_name] = layer_state_dict
         elif layer_name.startswith("lora_te1"):
             clip_grouped_sd[layer_name] = layer_state_dict
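
Note: a trimmed-down sketch of the grouping loop's effect (hypothetical layer names; the real grouped_state_dict maps layer names to their weight tensors). Final-layer entries are recognized by the format check but dropped here, while every other lora_unet layer is handed to the transformer group as before:

    grouped_state_dict = {
        "lora_unet_double_blocks_0_img_attn_qkv": {},  # hypothetical transformer layer, kept
        "lora_unet_final_layer_linear": {},            # final layer, skipped
    }

    transformer_grouped_sd = {}
    for layer_name, layer_state_dict in grouped_state_dict.items():
        if layer_name.startswith("lora_unet"):
            if layer_name.startswith("lora_unet_final_layer"):
                continue
            transformer_grouped_sd[layer_name] = layer_state_dict

    print(list(transformer_grouped_sd))  # ['lora_unet_double_blocks_0_img_attn_qkv']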
