Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] fix broken xformers tests #9206

Merged
merged 4 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,6 +2011,11 @@ def __call__(
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()

if attn.norm_q is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which test is this meant to fix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CogVideoX tests. This was needed because we added QK norm in Attn2.0 and FusedAttn2.0.

However, after the new CogVideoX-5B PR, I think we can no longer support XFormers due to needing a custom attention processor. Maybe we can skip the test here because otherwise we'd need a custom XFormers processor variant for it?

query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from diffusers.models.attention import FreeNoiseTransformerBlock
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
Expand Down Expand Up @@ -329,6 +330,13 @@ def test_prompt_embeds(self):
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs)

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)

def test_free_init(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
Expand Down
8 changes: 8 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
UNetMotionModel,
)
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
Expand Down Expand Up @@ -393,6 +394,13 @@ def test_prompt_embeds(self):
inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
pipe(**inputs)

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)

def test_free_init(self):
components = self.get_dummy_components()
pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
Expand Down
10 changes: 9 additions & 1 deletion tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,7 +1687,15 @@ def _test_xformers_attention_forwardGenerator_pass(
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")

if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
if torch.is_tensor(output_without_offload):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the issue is the output shapes it would be better to redefine the tests for that output shape inside the pipeline test modules. Similar to what's done here:

def test_xformers_attention_forwardGenerator_pass(self):

if output_without_offload.ndim == 5:
# Educated guess that the original format here is [B, F, C, H, W] and we
# permute to [B, F, H, W, C] to make input compatible with mean pixel difference
output_without_offload = output_without_offload.permute(0, 1, 3, 4, 2)[0]
output_with_offload = output_with_offload.permute(0, 1, 3, 4, 2)[0]
output_without_offload = to_np(output_without_offload)
output_with_offload = to_np(output_with_offload)
assert_mean_pixel_difference(to_np(output_with_offload[0]), to_np(output_without_offload[0]))

def test_progress_bar(self):
components = self.get_dummy_components()
Expand Down
Loading