Implements Blockwise lora #7352

Merged · 40 commits · Mar 29, 2024

Changes from 1 commit

Commits (40)
8404d7b
Initial commit
UmerHA Mar 15, 2024
84125df
Implemented block lora
UmerHA Mar 16, 2024
7405aff
Finishing up
UmerHA Mar 16, 2024
5c19f18
Reverted unrelated changes made by make style
UmerHA Mar 16, 2024
769f42b
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 16, 2024
8908c90
Fixed typo
UmerHA Mar 16, 2024
d9c55a5
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 16, 2024
7e6ce83
Fixed bug + Made text_encoder_2 scalable
UmerHA Mar 16, 2024
3c841fc
Integrated some review feedback
UmerHA Mar 18, 2024
72b8752
Incorporated review feedback
UmerHA Mar 19, 2024
2247bcb
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 19, 2024
145c7f3
Fix tests
UmerHA Mar 19, 2024
8e26004
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 19, 2024
87e54b4
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 20, 2024
83ff34b
Made every module configurable
UmerHA Mar 20, 2024
5054f02
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 20, 2024
c2395fa
Adapted to new lora test structure
UmerHA Mar 21, 2024
624b2dd
Final cleanup
UmerHA Mar 21, 2024
578e974
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
0b32d64
Some more final fixes
UmerHA Mar 21, 2024
2b4aae6
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 21, 2024
38038b7
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
7411cab
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
3ed3ca5
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
df9df2e
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 22, 2024
24d376f
Make style, quality, fix-copies
UmerHA Mar 22, 2024
8fa6c25
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 22, 2024
7dfa8e3
Updated tutorial; Warning if scale/adapter mismatch
UmerHA Mar 22, 2024
9c6f613
floats are forwarded as-is; changed tutorial scale
UmerHA Mar 23, 2024
a469a4d
make style, quality, fix-copies
UmerHA Mar 23, 2024
957358b
Fixed typo in tutorial
UmerHA Mar 23, 2024
cb062b6
Moved some warnings into `lora_loader_utils.py`
UmerHA Mar 23, 2024
1e61dfb
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 23, 2024
a4a38df
Moved scale/lora mismatch warnings back
UmerHA Mar 24, 2024
9aa1479
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 27, 2024
625045a
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 28, 2024
2939e45
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 29, 2024
14fabf0
Integrated final review suggestions
UmerHA Mar 29, 2024
8500161
Empty commit to trigger CI
UmerHA Mar 29, 2024
74ce9bb
Reverted empty commit to trigger CI
UmerHA Mar 29, 2024
Incorporated review feedback
UmerHA committed Mar 19, 2024
commit 72b8752c915fe8a6ba50fdc2a3ae8aa0f46c91d4
31 changes: 16 additions & 15 deletions docs/source/en/using-diffusers/loading_adapters.md
@@ -153,32 +153,33 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>

<Tip>
To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
```

### Adjust LoRA weight scale

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

For fine-grained control on how much of the LoRA weights are used, use [`~loaders.LoraLoaderMixin.set_adapters`]. Here, you can define scale of any granularity up to per-transformer.
For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying how much to scale the weights in each layer by.
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
"text_encoder": 0.5,
"text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # because "mid" is not given, all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
"unet": {
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
}
pipe.load_lora_weights("my_adapter", scales)
```
</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
pipe.set_adapters("my_adapter", scales)
```

### Kohya and TheLastBen
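For readers following the documentation change above, here is a minimal end-to-end sketch of the per-block scaling API it describes. The SDXL checkpoint id, the LoRA path, and the adapter name are placeholders, not values taken from this PR:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint and LoRA file -- substitute your own.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="my_adapter")

scales = {
    "text_encoder": 0.5,
    "text_encoder_2": 0.5,  # only meaningful if the pipeline has a second text encoder
    "unet": {
        "down": 0.9,  # every transformer in the down-blocks uses 0.9
        # "mid" omitted -> defaults to 1.0
        "up": {
            "block_0": 0.6,              # all transformers in up-block 0 use 0.6
            "block_1": [0.4, 0.8, 1.0],  # per-transformer scales in up-block 1
        },
    },
}
pipe.set_adapters("my_adapter", scales)
image = pipe("a cozy cabin in the woods, watercolor").images[0]
```

Per the `Union[float, Dict]` signatures introduced in this diff, passing a single float instead of the nested dict still scales every component uniformly.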
69 changes: 25 additions & 44 deletions src/diffusers/loaders/lora.py
Expand Up @@ -14,6 +14,7 @@
import inspect
import os
from pathlib import Path
from types import NoneType
from typing import Callable, Dict, List, Optional, Union

import safetensors
@@ -959,7 +960,7 @@ def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: Optional[Union[float, List[float]]] = None,
text_encoder_weights: Optional[Union[float, List[float], List[NoneType]]] = None,
):
"""
Sets the adapter layers for the text encoder.
@@ -977,17 +978,16 @@ def set_adapters_for_text_encoder(
raise ValueError("PEFT backend is required for this method.")

def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)

weights = [{"text_model": w} if w is not None else {"text_model": 1.0} for w in weights]
weights = [w or 1.0 for w in weights] # Set None values to default of 1.0
weights = [{"text_model": w} for w in weights]

return weights

@@ -1036,61 +1036,42 @@ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"]
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[Union[List[float], float, List[Dict], Dict]] = None,
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

allowed_numeric_dtypes = (float, int)
has_second_text_encoder = hasattr(self, "text_encoder_2")
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

# Expand weights into a list, one entry per adapter
if adapter_weights is None or isinstance(adapter_weights, (allowed_numeric_dtypes, dict)):
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)

if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)

# Normalize into dicts
allowed_keys = ["text_encoder", "unet"]
if has_second_text_encoder:
allowed_keys.append("text_encoder_2")

def ensure_is_dict(weight):
if isinstance(weight, dict):
return weight
elif isinstance(weight, allowed_numeric_dtypes):
return {key: weight for key in allowed_keys}
elif weight is None:
return {key: 1.0 for key in allowed_keys}
else:
raise RuntimeError(f"lora weight has wrong type {type(weight)}.")

adapter_weights = [ensure_is_dict(weight) for weight in adapter_weights]
# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_weights, text_encoder_weights, text_encoder_2_weights = [], [], []

for weights in adapter_weights:
for k in weights.keys():
if k not in allowed_keys:
raise ValueError(
f"Got invalid key '{k}' in lora weight dict. Allowed keys are 'text_encoder', 'text_encoder_2', 'down', 'mid', 'up'."
)
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
unet_weight = weights.pop("unet", None)
text_encoder_weight = weights.pop("text_encoder", None)
text_encoder_2_weight = weights.pop("text_encoder_2", None)

# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_weights, text_encoder_weights = [], []
if has_second_text_encoder:
text_encoder_2_weights = []
if len(weights) > 0:
raise ValueError(f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}.")

for weights in adapter_weights:
unet_weight = weights.get("unet", None)
text_encoder_weight = weights.get("text_encoder", None)
if has_second_text_encoder:
text_encoder_2_weight = weights.get("text_encoder_2", None)
if text_encoder_2_weight is not None and not hasattr(self, "text_encoder_2"):
logger.warning("Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2.")
else:
unet_weight = weights
text_encoder_weight = weights
text_encoder_2_weight = weights

unet_weights.append(unet_weight)
text_encoder_weights.append(text_encoder_weight)
if has_second_text_encoder:
text_encoder_2_weights.append(text_encoder_2_weight)
text_encoder_2_weights.append(text_encoder_2_weight)

unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
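To make the decomposition in the hunk above easier to follow, here is a simplified, standalone sketch of the idea: each adapter's entry is split into unet, text_encoder, and text_encoder_2 parts, with plain floats broadcast to all three. This is an illustration under stated assumptions, not the library's exact code; `has_text_encoder_2` stands in for the pipeline attribute check.

```python
from typing import Dict, List, Optional, Union

Weight = Optional[Union[float, Dict]]


def decompose_adapter_weights(
    adapter_names: List[str],
    adapter_weights: List[Weight],
    has_text_encoder_2: bool,
):
    """Split per-adapter weights into per-component lists (illustration only)."""
    unet_weights, te_weights, te2_weights = [], [], []
    for name, weights in zip(adapter_names, adapter_weights):
        if isinstance(weights, dict):
            weights = dict(weights)  # copy so pop() does not mutate the caller's dict
            unet_w = weights.pop("unet", None)
            te_w = weights.pop("text_encoder", None)
            te2_w = weights.pop("text_encoder_2", None)
            if weights:  # anything left over is an unrecognized key
                raise ValueError(f"Invalid keys {sorted(weights)} in lora weight dict for adapter '{name}'.")
            if te2_w is not None and not has_text_encoder_2:
                print(f"Ignoring text_encoder_2 weight for '{name}': pipeline has no second text encoder.")
                te2_w = None
        else:
            # a plain float (or None) applies to every component
            unet_w = te_w = te2_w = weights
        unet_weights.append(unet_w)
        te_weights.append(te_w)
        te2_weights.append(te2_w)
    return unet_weights, te_weights, te2_weights


# One adapter with per-component scales, one with a single float for everything
unet_w, te_w, te2_w = decompose_adapter_weights(
    ["style", "detail"],
    [{"unet": {"down": 0.9}, "text_encoder": 0.5}, 0.8],
    has_text_encoder_2=True,
)
print(unet_w)  # [{'down': 0.9}, 0.8]
print(te_w)    # [0.5, 0.8]
print(te2_w)   # [None, 0.8]
```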
41 changes: 24 additions & 17 deletions src/diffusers/loaders/unet.py
@@ -17,6 +17,7 @@
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from types import NoneType
from typing import Callable, Dict, List, Optional, Union

import safetensors
@@ -562,26 +563,36 @@ def _unfuse_lora_apply(self, module):
module.unmerge()

def _expand_lora_scales_dict(
self, scales, blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int]
self, scales: Union[float, Dict], blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int]
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.
Expands the inputs into a more granular dictionary. See the example below for more details.

Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has

E.g. turns
{
scales = {
'down': 2,
'mid': 3,
'up': {
'block_1': 4,
'block_2': [5, 6, 7]
'block_0': 4,
'block_1': [5, 6, 7]
}
}
blocks_with_transformer = {
'down': [1,2],
'up': [0,1]
}
transformer_per_block = {
'down': 2,
'up': 3
}
into
{
'down.block_1.0': 2,
@@ -597,15 +608,13 @@ def _expand_lora_scales_dict(
'up.block_1.2': 7,
}
"""
allowed_numeric_dtypes = (float, int)

if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

if sorted(transformer_per_block.keys()) != ["down", "up"]:
raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

if isinstance(scales, allowed_numeric_dtypes):
if not isinstance(scales, dict):
scales = {o: scales for o in ["down", "mid", "up"]}

if "mid" not in scales:
@@ -616,13 +625,13 @@
scales[updown] = 1

# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if isinstance(scales[updown], allowed_numeric_dtypes):
if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
if isinstance(scales[updown][block], allowed_numeric_dtypes):
if not isinstance(scales[updown][block], dict):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]

# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
@@ -650,15 +659,15 @@ def layer_name(name):
for layer in scales.keys():
if not any(layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or has not attentions."
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
)

return {layer_name(name): weight for name, weight in scales.items()}

def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float, List[Dict], Dict]] = None,
weights: Optional[Union[float, Dict, List[float], List[Dict], List[NoneType]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
@@ -691,18 +700,16 @@ def set_adapters(

adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, (float, dict)):
# Expand weights into a list, one entry per adapter
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)

if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)

# Set missing value to default of 1.0
weights = [weight or 1.0 for weight in weights]
weights = [weight or 1.0 for weight in weights] # Set None values to default of 1.0
blocks_with_transformer = {
"down": [i for i, block in enumerate(self.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(self.up_blocks) if hasattr(block, "attentions")],
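Finally, a compact standalone sketch of the expansion rules described in the `_expand_lora_scales_dict` docstring above. The real method additionally validates its inputs and maps the expanded keys onto UNet state-dict layer names, so treat this only as an illustration of how a compact scales dict fans out to one scale per transformer:

```python
from typing import Dict, List, Union

Scale = Union[int, float]


def expand_lora_scales(
    scales: Union[Scale, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
) -> Dict[str, Scale]:
    """Expand a compact scales dict into one scale per transformer (illustration only)."""
    if not isinstance(scales, dict):
        scales = {part: scales for part in ("down", "mid", "up")}
    expanded: Dict[str, Scale] = {"mid": scales.get("mid", 1.0)}

    for updown in ("down", "up"):
        part = scales.get(updown, 1.0)
        if not isinstance(part, dict):  # a bare number applies to every block in this part
            part = {f"block_{i}": part for i in blocks_with_transformer[updown]}
        for i in blocks_with_transformer[updown]:
            block = part.get(f"block_{i}", 1.0)
            if not isinstance(block, (list, tuple)):  # a bare number applies to every transformer
                block = [block] * transformer_per_block[updown]
            for t, scale in enumerate(block):
                expanded[f"{updown}.block_{i}.{t}"] = scale
    return expanded


# Reproduces the docstring example from the diff above
print(expand_lora_scales(
    {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}},
    blocks_with_transformer={"down": [1, 2], "up": [0, 1]},
    transformer_per_block={"down": 2, "up": 3},
))
# {'mid': 3, 'down.block_1.0': 2, 'down.block_1.1': 2, 'down.block_2.0': 2, 'down.block_2.1': 2,
#  'up.block_0.0': 4, 'up.block_0.1': 4, 'up.block_0.2': 4,
#  'up.block_1.0': 5, 'up.block_1.1': 6, 'up.block_1.2': 7}
```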