Skip to content
13 changes: 10 additions & 3 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,20 @@ def load_lora_adapter(
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)

# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)

# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
Expand Down
57 changes: 49 additions & 8 deletions src/diffusers/utils/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
def get_peft_kwargs(
rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
Expand Down Expand Up @@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()

# layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
Expand All @@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"use_dora": use_dora,
"lora_bias": lora_bias,
}

# Example: try load FusionX LoRA into Wan VACE
exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
if exclude_modules:
if not is_peft_version(">=", "0.14.0"):
msg = """
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
https://github.com/huggingface/diffusers/issues/new
"""
logger.debug(msg)
else:
lora_config_kwargs.update({"exclude_modules": exclude_modules})

return lora_config_kwargs


Expand Down Expand Up @@ -294,19 +310,20 @@ def check_peft_version(min_version: str) -> None:


def _create_lora_config(
state_dict,
network_alphas,
metadata,
rank_pattern_dict,
is_unet: bool = True,
state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
from peft import LoraConfig

if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
rank_pattern_dict,
network_alpha_dict=network_alphas,
peft_state_dict=state_dict,
is_unet=is_unet,
model_state_dict=model_state_dict,
adapter_name=adapter_name,
)

_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
Expand Down Expand Up @@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):

if warn_msg:
logger.warning(warn_msg)


def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
"""
Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
`model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
doesn't exist in `peft_state_dict`.
"""
if model_state_dict is None:
return
all_modules = set()
string_to_replace = f"{adapter_name}." if adapter_name else ""

for name in model_state_dict.keys():
if string_to_replace:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the if-statement is not needed here because string_to_replace will be an empty string, but no problem keeping as micro optimization

name = name.replace(string_to_replace, "")
if "." in name:
module_name = name.rsplit(".", 1)[0]
all_modules.add(module_name)

target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
exclude_modules = list(all_modules - target_modules_set)

return exclude_modules
6 changes: 5 additions & 1 deletion tests/lora/test_lora_layers_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
WanPipeline,
WanTransformer3DModel,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
)


sys.path.append(".")
Expand Down
67 changes: 67 additions & 0 deletions tests/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
import re
Expand Down Expand Up @@ -291,6 +292,20 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):

return modules_to_save

def _get_exclude_modules(self, pipe):
from diffusers.utils.peft_utils import _derive_exclude_modules

modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
denoiser = "unet" if self.unet_kwargs is not None else "transformer"
modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
pipe.unload_lora_weights()
denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
exclude_modules = _derive_exclude_modules(
denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
)
return exclude_modules

def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
Expand Down Expand Up @@ -2326,6 +2341,58 @@ def test_lora_unload_add_adapter(self):
)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]

@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules(self):
"""
Test to check if `exclude_modules` works or not. It works in the following way:
we first create a pipeline and insert LoRA config into it. We then derive a `set`
of modules to exclude by investigating its denoiser state dict and denoiser LoRA
state dict.

We then create a new LoRA config to include the `exclude_modules` and perform tests.
"""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)

# only supported for `denoiser` now
pipe_cp = copy.deepcopy(pipe)
pipe_cp, _ = self.add_adapters_to_pipeline(
pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid I don't fully understand how this test works but that shouldn't be a blocker. If this tests the exact same condition we're facing in the FusionX lora, then it should be good :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, not the cleanest test this one. I added a description in the test, see if that helps?

Currently, we don't have:

If this tests the exact same condition we're facing in the FusionX lora

I will think of a way to include a test for that, too.

pipe_cp.to("cpu")
del pipe_cp

denoiser_lora_config.exclude_modules = denoiser_exclude_modules
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]

with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)

output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)

def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
Expand Down