Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions tests/lora/test_lora_layers_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
import sys
import unittest

import numpy as np
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel

from ..testing_utils import floats_tensor, require_peft_backend
from ..testing_utils import floats_tensor, require_peft_backend, torch_device


sys.path.append(".")

from .utils import PeftLoraLoaderMixinTests # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402


@require_peft_backend
Expand Down Expand Up @@ -94,6 +95,44 @@ def get_dummy_inputs(self, with_generator=True):

return noise, input_ids, pipeline_inputs

def test_lora_fuse_nan(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)

denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

# corrupt one LoRA weight with `inf` values
with torch.no_grad():
possible_tower_names = ["transformer_blocks", "single_transformer_blocks"]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
if len(filtered_tower_names) == 0:
reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
raise ValueError(reason)
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
is_single = "single" in tower_name
if is_single:
transformer_tower[0].attn.to_qkv_mlp_proj.lora_A["adapter-1"].weight += float("inf")
else:
transformer_tower[0].attn.to_k.lora_A["adapter-1"].weight += float("inf")

# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]

self.assertTrue(np.isnan(out).all())

@unittest.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
Expand Down
Loading