
Implements Blockwise lora #7352

Merged: 40 commits, Mar 29, 2024
Changes from 1 commit
Commits (40)
8404d7b
Initial commit
UmerHA Mar 15, 2024
84125df
Implemented block lora
UmerHA Mar 16, 2024
7405aff
Finishing up
UmerHA Mar 16, 2024
5c19f18
Reverted unrelated changes made by make style
UmerHA Mar 16, 2024
769f42b
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 16, 2024
8908c90
Fixed typo
UmerHA Mar 16, 2024
d9c55a5
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 16, 2024
7e6ce83
Fixed bug + Made text_encoder_2 scalable
UmerHA Mar 16, 2024
3c841fc
Integrated some review feedback
UmerHA Mar 18, 2024
72b8752
Incorporated review feedback
UmerHA Mar 19, 2024
2247bcb
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 19, 2024
145c7f3
Fix tests
UmerHA Mar 19, 2024
8e26004
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 19, 2024
87e54b4
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 20, 2024
83ff34b
Made every module configurable
UmerHA Mar 20, 2024
5054f02
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 20, 2024
c2395fa
Adapted to new lora test structure
UmerHA Mar 21, 2024
624b2dd
Final cleanup
UmerHA Mar 21, 2024
578e974
Merge branch 'huggingface:main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
0b32d64
Some more final fixes
UmerHA Mar 21, 2024
2b4aae6
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 21, 2024
38038b7
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
7411cab
Merge remote-tracking branch 'upstream/main' into 7231-blockwise-lora
UmerHA Mar 21, 2024
3ed3ca5
Update using_peft_for_inference.md
UmerHA Mar 21, 2024
df9df2e
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 22, 2024
24d376f
Make style, quality, fix-copies
UmerHA Mar 22, 2024
8fa6c25
Merge branch '7231-blockwise-lora' of https://github.com/UmerHA/diffu…
UmerHA Mar 22, 2024
7dfa8e3
Updated tutorial; warning if scale/adapter mismatch
UmerHA Mar 22, 2024
9c6f613
floats are forwarded as-is; changed tutorial scale
UmerHA Mar 23, 2024
a469a4d
make style, quality, fix-copies
UmerHA Mar 23, 2024
957358b
Fixed typo in tutorial
UmerHA Mar 23, 2024
cb062b6
Moved some warnings into `lora_loader_utils.py`
UmerHA Mar 23, 2024
1e61dfb
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 23, 2024
a4a38df
Moved scale/lora mismatch warnings back
UmerHA Mar 24, 2024
9aa1479
Merge branch 'main' into 7231-blockwise-lora
UmerHA Mar 27, 2024
625045a
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 28, 2024
2939e45
Merge branch 'main' into 7231-blockwise-lora
sayakpaul Mar 29, 2024
14fabf0
Integrated final review suggestions
UmerHA Mar 29, 2024
8500161
Empty commit to trigger CI
UmerHA Mar 29, 2024
74ce9bb
Reverted empty commit to trigger CI
UmerHA Mar 29, 2024
Fix tests
UmerHA committed Mar 19, 2024
commit 145c7f30719c29fbbb0f9cbd22a15287dd16f122
19 changes: 11 additions & 8 deletions src/diffusers/loaders/lora.py
@@ -1038,7 +1038,6 @@ def set_adapters(
        adapter_names: Union[List[str], str],
        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
    ):
-
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        # Expand weights into a list, one entry per adapter
@@ -1059,15 +1058,19 @@ def set_adapters(
                text_encoder_weight = weights.pop("text_encoder", None)
                text_encoder_2_weight = weights.pop("text_encoder_2", None)

-                if len(weights) >0:
-                    raise ValueError(f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}.")
+                if len(weights) > 0:
+                    raise ValueError(
+                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
+                    )

-                if text_encoder_2_weight is not None and not hasattr(self, "text_encoder_2"):
-                    logger.warning("Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2.")
+                if text_encoder_2_weight is not None and not hasattr(self, "text_encoder_2"):
+                    logger.warning(
+                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
+                    )
            else:
                unet_weight = weights
                text_encoder_weight = weights
                text_encoder_2_weight = weights

            unet_weights.append(unet_weight)
            text_encoder_weights.append(text_encoder_weight)
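
The hunk above tightens validation of the per-component weight dict that `set_adapters` accepts once blockwise LoRA lands. As a rough usage sketch of that API (the checkpoint, the LoRA repository, the adapter name, and the scale values are illustrative placeholders, not taken from this PR):

import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint and LoRA repo, for illustration only.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sdxl-lora", adapter_name="example")

# The weight dict may target "unet", "text_encoder" and "text_encoder_2";
# the unet entry can be refined per "down"/"mid"/"up", per block, and per transformer.
scales = {
    "unet": {
        "down": 0.8,  # every transformer in the down blocks
        "mid": 1.0,  # the mid block
        "up": {
            "block_0": 0.6,  # every transformer in up block 0
            "block_1": [0.4, 0.8, 1.0],  # one scale per transformer in up block 1
        },
    },
    "text_encoder": 0.5,
}
pipe.set_adapters("example", scales)
image = pipe("a cute corgi", generator=torch.manual_seed(0)).images[0]

Components and blocks left out of the dict appear to keep the default scale (1.0 in the tutorial updated by this PR).
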
10 changes: 8 additions & 2 deletions src/diffusers/loaders/unet.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import os
from collections import defaultdict
@@ -563,7 +564,10 @@ def _unfuse_lora_apply(self, module):
                module.unmerge()

    def _expand_lora_scales_dict(
-        self, scales: Union[float, Dict], blocks_with_transformer: Dict[str, int], transformer_per_block: Dict[str, int]
+        self,
+        scales: Union[float, Dict],
+        blocks_with_transformer: Dict[str, int],
+        transformer_per_block: Dict[str, int],
    ):
        """
        Expands the inputs into a more granular dictionary. See the example below for more details.
@@ -614,6 +618,8 @@ def _expand_lora_scales_dict(
        if sorted(transformer_per_block.keys()) != ["down", "up"]:
            raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

+        scales = copy.deepcopy(scales)
+
        if not isinstance(scales, dict):
            scales = {o: scales for o in ["down", "mid", "up"]}

@@ -631,7 +637,7 @@ def _expand_lora_scales_dict(
            # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
            for i in blocks_with_transformer[updown]:
                block = f"block_{i}"
-                if not isinstance(scales[updown][block], dict):
+                if not isinstance(scales[updown][block], list):
                    scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]

            # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
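
To make the expansion above easier to follow, here is a standalone, simplified re-implementation of what `_expand_lora_scales_dict` does conceptually. The function name, the example layout, and the default scale of 1.0 are assumptions made for illustration; this is not the library code:

from copy import deepcopy
from typing import Dict, List, Union

def expand_scales_sketch(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
) -> Dict[str, float]:
    """Toy version of the expansion: coarse scales in, one scale per transformer out."""
    scales = deepcopy(scales)  # mirrors the copy.deepcopy added in this commit

    # A bare number applies to the whole unet, e.g. 2 -> {"down": 2, "mid": 2, "up": 2}.
    if not isinstance(scales, dict):
        scales = {part: scales for part in ["down", "mid", "up"]}

    expanded = {"mid": scales.get("mid", 1.0)}
    for updown in ["down", "up"]:
        part = scales.get(updown, 1.0)
        # e.g. {"down": 2} -> {"down": {"block_1": 2, "block_2": 2}}
        if not isinstance(part, dict):
            part = {f"block_{i}": part for i in blocks_with_transformer[updown]}
        for i in blocks_with_transformer[updown]:
            value = part.get(f"block_{i}", 1.0)
            # e.g. {"block_1": 2} -> [2, 2], one entry per transformer in the block
            if not isinstance(value, list):
                value = [value] * transformer_per_block[updown]
            # flatten, e.g. "down.block_1.0": 2, "down.block_1.1": 2
            for position, scale in enumerate(value):
                expanded[f"{updown}.block_{i}.{position}"] = scale
    return expanded

# Hypothetical layout: down blocks 1 and 2 carry 2 transformers each, up blocks 0 and 1 carry 3 each.
print(
    expand_scales_sketch(
        {"down": 0.8, "up": {"block_1": [0.4, 0.8, 1.0]}},
        {"down": [1, 2], "up": [0, 1]},
        {"down": 2, "up": 3},
    )
)

The real helper additionally validates its inputs (as in the `sorted(transformer_per_block.keys())` check above); the sketch skips that to keep the expansion itself visible.
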
19 changes: 5 additions & 14 deletions tests/lora/test_lora_layers_peft.py
@@ -847,15 +847,15 @@ def test_simple_inference_with_text_unet_block_scale(self):
        )

        weights_1 = {
-            "unet" : {
+            "unet": {
                "down": 5,
            }
        }
        pipe.set_adapters("adapter-1", weights_1)
        output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images

        weights_2 = {
-            "unet" : {
+            "unet": {
                "up": 5,
            }
        }
@@ -915,17 +915,8 @@ def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
            self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
        )

-        scales_1 = {
-            "unet" : {
-                "down": 5
-            }
-        }
-        scales_2 = {
-            "unet" : {
-                "down": 5,
-                "mid": 5
-            }
-        }
+        scales_1 = {"unet": {"down": 5}}
+        scales_2 = {"unet": {"down": 5, "mid": 5}}
        pipe.set_adapters("adapter-1", scales_1)

        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
@@ -990,7 +981,7 @@ def updown_options(blocks_with_tf, layers_per_block, value):
        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
-            E.g. 2, {"down": 2}, {"down": [2,2,2]}, {"mid": 2, "up": [2,2,2]}, ...
+            E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
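
The tests above set one adapter at a time. Since the `set_adapters` signature earlier in this diff also accepts a list of weight entries (`List[Dict]`, one per adapter), a combined call could plausibly look like the sketch below; it assumes the same `pipe`, `inputs`, and adapter names that the test sets up:

# Sketch only: assumes "adapter-1" and "adapter-2" are already loaded on `pipe`.
pipe.set_adapters(
    ["adapter-1", "adapter-2"],
    [{"unet": {"down": 5}}, {"unet": {"down": 5, "mid": 5}}],
)
output_both = pipe(**inputs, generator=torch.manual_seed(0)).images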