Skip to content

Commit 10c8910

Browse files
author
Kevin Turner
committed
feat(LoRA): allow LoRA layer patcher to continue past unknown layers
1 parent f351ad4 commit 10c8910

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

invokeai/backend/patches/layer_patcher.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
from contextlib import contextmanager
23
from typing import Dict, Iterable, Optional, Tuple
34

@@ -7,6 +8,7 @@
78
from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
89
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
910
from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
11+
from invokeai.backend.util import InvokeAILogger
1012
from invokeai.backend.util.devices import TorchDevice
1113
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
1214

@@ -23,6 +25,7 @@ def apply_smart_model_patches(
2325
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
2426
force_direct_patching: bool = False,
2527
force_sidecar_patching: bool = False,
28+
suppress_warning_layers: Optional[re.Pattern] = None,
2629
):
2730
"""Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
2831
module.
@@ -44,6 +47,7 @@ def apply_smart_model_patches(
4447
dtype=dtype,
4548
force_direct_patching=force_direct_patching,
4649
force_sidecar_patching=force_sidecar_patching,
50+
suppress_warning_layers=suppress_warning_layers,
4751
)
4852

4953
yield
@@ -70,6 +74,7 @@ def apply_smart_model_patch(
7074
dtype: torch.dtype,
7175
force_direct_patching: bool,
7276
force_sidecar_patching: bool,
77+
suppress_warning_layers: Optional[re.Pattern] = None,
7378
):
7479
"""Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
7580
patching or a sidecar wrapper for each module.
@@ -89,9 +94,17 @@ def apply_smart_model_patch(
8994
if not layer_key.startswith(prefix):
9095
continue
9196

92-
module_key, module = LayerPatcher._get_submodule(
93-
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
94-
)
97+
try:
98+
module_key, module = LayerPatcher._get_submodule(
99+
model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
100+
)
101+
except AttributeError:
102+
if suppress_warning_layers and suppress_warning_layers.search(layer_key):
103+
pass
104+
else:
105+
logger = InvokeAILogger.get_logger(LayerPatcher.__name__)
106+
logger.warning("Failed to find module for LoRA layer key: %s", layer_key)
107+
continue
95108

96109
# Decide whether to use direct patching or a sidecar patch.
97110
# Direct patching is preferred, because it results in better runtime speed.

0 commit comments

Comments
 (0)