@@ -1,3 +1,4 @@
+import re
 from contextlib import contextmanager
 from typing import Dict, Iterable, Optional, Tuple
 
@@ -7,6 +8,7 @@
 from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 from invokeai.backend.patches.pad_with_zeros import pad_with_zeros
+from invokeai.backend.util import InvokeAILogger
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
@@ -23,6 +25,7 @@ def apply_smart_model_patches(
         cached_weights: Optional[Dict[str, torch.Tensor]] = None,
         force_direct_patching: bool = False,
         force_sidecar_patching: bool = False,
+        suppress_warning_layers: Optional[re.Pattern] = None,
     ):
         """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each
         module.
@@ -44,6 +47,7 @@ def apply_smart_model_patches(
                     dtype=dtype,
                     force_direct_patching=force_direct_patching,
                     force_sidecar_patching=force_sidecar_patching,
+                    suppress_warning_layers=suppress_warning_layers,
                 )
 
             yield
@@ -70,6 +74,7 @@ def apply_smart_model_patch(
         dtype: torch.dtype,
         force_direct_patching: bool,
         force_sidecar_patching: bool,
+        suppress_warning_layers: Optional[re.Pattern] = None,
     ):
         """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct
         patching or a sidecar wrapper for each module.
@@ -89,9 +94,17 @@ def apply_smart_model_patch(
             if not layer_key.startswith(prefix):
                 continue
 
-            module_key, module = LayerPatcher._get_submodule(
-                model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
-            )
+            try:
+                module_key, module = LayerPatcher._get_submodule(
+                    model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened
+                )
+            except AttributeError:
+                if suppress_warning_layers and suppress_warning_layers.search(layer_key):
+                    pass
+                else:
+                    logger = InvokeAILogger.get_logger(LayerPatcher.__name__)
+                    logger.warning("Failed to find module for LoRA layer key: %s", layer_key)
+                continue
 
             # Decide whether to use direct patching or a sidecar patch.
             # Direct patching is preferred, because it results in better runtime speed.
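
For reference, here is how a caller might use the new `suppress_warning_layers` argument. This is a minimal sketch, not code from the commit: the import path, the surrounding keyword arguments (`model`, `patches`, `prefix`), the placeholder values, and the suppression regex are all assumptions inferred from the parameter names visible in the hunks above; only `suppress_warning_layers` itself and `dtype` are taken directly from the diff.

```python
import re

import torch

# Assumed import path for LayerPatcher.
from invokeai.backend.patches.layer_patcher import LayerPatcher

# Hypothetical pattern: silence the "Failed to find module" warning for layer keys
# that are expected to be absent from the model being patched (e.g. text-encoder
# layers bundled in a LoRA that is only applied to the transformer).
suppress_te_layers = re.compile(r"text_encoder|lora_te\d*_")

# Placeholder inputs so the sketch is self-contained; a real caller would pass the
# loaded model and its (ModelPatchRaw, weight) pairs here.
model = torch.nn.Linear(8, 8)
patches = []

with LayerPatcher.apply_smart_model_patches(
    model=model,
    patches=patches,
    prefix="lora_unet_",  # placeholder prefix
    dtype=torch.float32,
    suppress_warning_layers=suppress_te_layers,
):
    # Unresolved layer keys that match the pattern are skipped silently; unresolved
    # keys that do not match still log a warning before being skipped.
    pass
```

Passing a compiled pattern rather than an exact key list keeps the check flexible: the patcher calls `Pattern.search` on each failing layer key, so any substring match suppresses the warning while the layer is still skipped.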