
Feature IP Adapter Xformers Attention Processor #9881

Merged

merged 10 commits into huggingface:main on Nov 9, 2024

Conversation

elismasilva
Contributor

@elismasilva elismasilva commented Nov 6, 2024

What does this PR do?

This PR fixes the incorrect attention processor being set when xFormers attention is enabled after the IP adapter has been loaded and its scale set.

The solution was described in #8872.

Fixes #8863, #8872

Test Code

import numpy as np
import torch
from diffusers import AutoPipelineForText2Image
from transformers import CLIPVisionModelWithProjection
from diffusers.utils.loading_utils import load_image

MAX_SEED = np.iinfo(np.int32).max
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
device = "cuda"
seed = 42

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
).to(device)

# load SDXL pipeline
pipe = AutoPipelineForText2Image.from_pretrained(
    base_model_path,    
    torch_dtype=torch.float16,
    image_encoder=image_encoder
).to(device)

#DEFAULT RUNNING ATTENTION IS 2.0
pipe.enable_vae_tiling() 
pipe.enable_model_cpu_offload()
#pipe.enable_xformers_memory_efficient_attention() #UNCOMMENT this line to run WITH XFORMERS before the IP adapter is loaded.

# load ip-adapter
pipe.load_ip_adapter("h94/IP-Adapter",
    subfolder="sdxl_models",    
    #weight_name="ip-adapter-plus-face_sdxl_vit-h.bin",
    weight_name="ip-adapter-plus_sdxl_vit-h.bin",
    image_encoder_folder=None,
)

# configure ip-adapter scales.
scale = {
    #"down": {"block_2": [0.0, 1.0]}, #composition
    "up": {"block_0": [0.0, 1.0, 0.0]}, #style
}

pipe.set_ip_adapter_scale(scale)
pipe.enable_xformers_memory_efficient_attention() #UNCOMMENT this line to run WITH XFORMERS after the IP adapter is loaded.

#pipe.disable_xformers_memory_efficient_attention() #UNCOMMENT this line to DISABLE XFORMERS and back to Attention 2.0.

generator = torch.Generator(device).manual_seed(seed)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
image = image.resize((512, 512))

# generate image
# DO BREAKPOINT HERE AND CHECK ATTN PROCESSORS IN PIPE.UNET COMPONENT
images = pipe(
    prompt="a cat, masterpiece, best quality, high quality",
    negative_prompt= "text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    ip_adapter_image=image,        
    guidance_scale=5,
    height=1024,
    width=1024,    
    num_inference_steps=30,     
    generator=generator
).images[0]

images.save("./data/result_1_diff.png")
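
Instead of setting a breakpoint, the active processors can be inspected directly (a small sketch, not part of the PR):

# Print the set of attention processor classes currently installed on the UNet.
# With xFormers enabled, the IP adapter layers should report
# IPAdapterXFormersAttnProcessor and the remaining layers XFormersAttnProcessor.
print({proc.__class__.__name__ for proc in pipe.unet.attn_processors.values()})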


Who can review?

requested by @a-r-r-o-w
@yiyixuxu and @asomoza

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks for the PR!
I left some questions, mostly on the IP adapter loading part. If you can help me understand why you added each change (e.g. the use cases you had in mind when adding them), it would be really helpful!!

IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
if ('XFormers' not in str(self.attn_processors[name].__class__)):
attn_processor_class = (
Collaborator

can you explain this code change here

  • previously, if I understand the code correctly, we kept the original attention processor for motion modules (did not change it to an IP adapter attention processor)
  • now, we change to the default attention processor when it is not an xFormers processor?

Contributor Author

Hi, how are you? Help me check whether my reasoning is correct. This condition is true when "cross_attention_dim" is None or when "motion_modules" is in the name; I only paid attention to the scenario where "cross_attention_dim" is None. Is there any attention class specific to models with "motion_modules" other than AttnProcessor and AttnProcessor2_0? If there were, I would just move "motion_modules" into an "elif". This part of the code comes from a first solution I implemented some time ago, before I had implemented the processor replacement in the "set_use_memory_efficient_attention_xformers" method of the "Attention" class. Back then, while testing several adapters and combined adapters, I probably ran into a situation that made me force this xFormers check here. Now that you mention it, I commented this part out and ran some more tests, and it seems this modification is no longer necessary since "set_use_memory_efficient_attention_xformers" was implemented. At least for now, I haven't hit any errors when loading.
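
For reference, a minimal sketch of the branch being discussed (an assumed simplification for illustration, not the exact diffusers source; choose_processor is a hypothetical helper):

import torch.nn.functional as F
from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0

def choose_processor(name, cross_attention_dim, current_processor):
    # Self-attention layers (cross_attention_dim is None) and motion modules keep
    # their current processor class; cross-attention layers get an IP adapter
    # processor, picked according to torch's SDPA availability.
    if cross_attention_dim is None or "motion_modules" in name:
        return current_processor.__class__
    if hasattr(F, "scaled_dot_product_attention"):
        return IPAdapterAttnProcessor2_0
    return IPAdapterAttnProcessor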

Contributor Author

If you agree, I will commit an update that removes this verification and restores the original code.

Collaborator

for this, can you provide a code example that would fail without this change?

Contributor Author

for this, can you provide a code example that would fail without this change?

Yes, see my test code in the PR: there are two lines with #pipe.enable_xformers_memory_efficient_attention(), and you can remove the # to enable xFormers either before or after the IP adapter is loaded. I put one line before and one after loading the model.

Contributor Author

I will commit my latest code with some fixes for the quality check.

attn_procs[name] = attn_processor_class()
else:
attn_procs[name] = self.attn_processors[name]
else:
Collaborator

can you explain the change here?

Contributor Author

Here in the else it seems like a chicken-and-egg situation. When we do not call pipe.enable_xformers_memory_efficient_attention(), the loader by default sets IPAdapterAttnProcessor2_0 or IPAdapterAttnProcessor for the IP adapter. So when you call pipe.enable_xformers_memory_efficient_attention() before loading the IP adapter, all attention processors are already set to XFormersAttnProcessor, and when the IP adapter modules are loaded after that call we need to check whether the active mechanism is xFormers in order to apply the new IPAdapterXFormersAttnProcessor class. However, when you call pipe.enable_xformers_memory_efficient_attention() after loading the IP adapter modules, those modules have already been set to IPAdapterAttnProcessor2_0 or IPAdapterAttnProcessor, and the set_use_memory_efficient_attention_xformers method of the Attention class only knew how to set everything to XFormersAttnProcessor, which produced the error reported in the open issue. With the implementation I added to that class, the method now also recognizes IPAdapterAttnProcessor2_0 / IPAdapterAttnProcessor in the modules and correctly replaces them with the new class; it can only do that because those processors were set when the module was loaded. So these checks are necessary on both sides, depending on whether pipe.enable_xformers_memory_efficient_attention() is called before or after loading the modules.
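
In other words, these are the two call orders the PR has to handle (a sketch reusing the pipeline and adapter from the test code above; run one order per session):

# Order 1: xFormers first, then the IP adapter.
# load_ip_adapter must notice the existing XFormersAttnProcessor and install
# IPAdapterXFormersAttnProcessor on the cross-attention layers.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models",
    weight_name="ip-adapter-plus_sdxl_vit-h.bin", image_encoder_folder=None,
)

# Order 2: IP adapter first, then xFormers.
# Attention.set_use_memory_efficient_attention_xformers must recognize
# IPAdapterAttnProcessor / IPAdapterAttnProcessor2_0 and swap them for
# IPAdapterXFormersAttnProcessor instead of a plain XFormersAttnProcessor.
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models",
    weight_name="ip-adapter-plus_sdxl_vit-h.bin", image_encoder_folder=None,
)
pipe.enable_xformers_memory_efficient_attention()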

Collaborator

Can you confirm that you added this change in order to be able to handle this?

pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter()

Contributor Author

Yes, this is to handle that order, and vice versa. The test code I provided in this PR simulates both scenarios.

@yiyixuxu
Collaborator

yiyixuxu commented Nov 8, 2024

cc @fabiorigano here too if you have time to give this a review!

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks for making the code simpler! I left a few more questions

@@ -369,7 +369,20 @@ def set_use_memory_efficient_attention_xformers(
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
processor = self.processor
if isinstance(self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
Collaborator

let's add an is_ip_adapter flag similar to is_custom_diffusion etc.

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
)

Contributor Author

@elismasilva elismasilva Nov 8, 2024

Yes, that is perfectly possible, but it will have to be like the snippet below, so that modules that have already been switched to the xFormers IP adapter class are not replaced again with XFormersAttnProcessor in the final else during the method's recursion.

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor,
    (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)

Collaborator

yes, the code you show here is ok!
We just want to keep a consistent style, that's all :)

scale=self.processor.scale,
num_tokens=self.processor.num_tokens,
attention_op=attention_op)
processor.load_state_dict(self.processor.state_dict())
Collaborator

why do we need to load_state_dict again here?

Contributor Author

@elismasilva elismasilva Nov 8, 2024

Well, I couldn't identify exactly why, but if I don't reload the state_dict here after assigning the new class, the IP adapter has no effect on the final image. I don't know if it's because the call to "pipe.enable_xformers_memory_efficient_attention()" happens after the IP adapter weights have already been loaded, so it's as if the adapter were not being used. I saw that some manipulation is done while loading the IP adapter weights, but I don't think it makes sense to replicate that logic here, and I don't know whether that's the reason. See one final image generated without the state dict reload and another with it. I noticed that something similar is done for custom diffusion, so for practicality I did the same. If you have a better solution I would like to try it.

Without load_state_dict: (image: result_1_diff)

With load_state_dict: (image: result_1_diff)

Collaborator

oh yeah, it comes with weights:

self.to_k_ip = nn.ModuleList(

(I had forgotten about that, sorry! lol)
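
For context, the swap being discussed looks roughly like this inside set_use_memory_efficient_attention_xformers (a hedged sketch based on this PR's diff, not a verbatim copy):

processor = IPAdapterXFormersAttnProcessor(
    hidden_size=self.processor.hidden_size,
    cross_attention_dim=self.processor.cross_attention_dim,
    scale=self.processor.scale,
    num_tokens=self.processor.num_tokens,
    attention_op=attention_op,
)
# Copy the already-loaded to_k_ip / to_v_ip weights into the new processor;
# without this the IP adapter has no effect on the output (the difference shown
# in the two images above).
processor.load_state_dict(self.processor.state_dict())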

num_tokens=self.processor.num_tokens,
attention_op=attention_op)
processor.load_state_dict(self.processor.state_dict())
if len(self.processor._modules) > 0:
Collaborator

what does this section of code do?

Contributor Author

This passes the initialization parameters that were already defined while loading the IP adapter model on to the new xFormers attention class. After that I reload the already-loaded state_dict into the new object, as explained in the previous answer. Then I make sure the weights keep the same device and dtype they had before, because reloading the state_dict leaves them on "cpu" with dtype "float32".
I changed the initial if statement on line 380 to check for existing modules, just to avoid unexpected errors.

if hasattr(self.processor, "_modules") and len(self.processor._modules) > 0:

Collaborator

ok, I think we can simplify the code here a little bit: because we are inside an if statement, we already know the processor will be either IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0 or IPAdapterXFormersAttnProcessor -- in all three cases it will have a to_k_ip and a to_v_ip layer, so maybe we can just get the device info from self.to_k_ip[0].device

Contributor Author

Oh yes, I had done it in an agnostic way, because I wasn't sure those modules would always exist in every model that might arrive there. I'll change it to your solution.
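
A sketch of that simplification (assumed for illustration; to_k_ip is an nn.ModuleList of Linear layers on the IP adapter processors, so its first entry carries the device and dtype of the loaded weights):

ip_weight = self.processor.to_k_ip[0].weight
# Move the freshly built xFormers processor onto the same device/dtype as the
# already-loaded IP adapter weights instead of probing arbitrary submodules.
processor.to(device=ip_weight.device, dtype=ip_weight.dtype)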

src/diffusers/models/attention_processor.py (comment thread resolved)
Contributor

@fabiorigano fabiorigano left a comment

@elismasilva great work! I left one comment. I agree with YiYi's feedback on src/diffusers/models/attention_processor.py, though I think it's fine to keep the class name as originally defined.
@yiyixuxu thank you for letting me review this PR :)

def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, attention_op: Optional[Callable] = None):
super().__init__()

if not hasattr(F, "scaled_dot_product_attention"):
Contributor

this check can be removed because this class uses xformers.ops.memory_efficient_attention instead of torch.nn.functional.scaled_dot_product_attention

Contributor

@fabiorigano fabiorigano left a comment

just adding one more consideration

Comment on lines 4677 to 4681
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
Contributor

Suggested change:

Before:
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)

After:
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)

Just another observation: if we make the tensors contiguous here, we can avoid multiple calls to query.contiguous() later in the code (every time self._memory_efficient_attention_xformers is called, query is reused).
This way we can call xformers.ops.memory_efficient_attention directly.

Comment on lines 4741 to 4744
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

_current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Contributor

Suggested change:

Before:
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
_current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)

After:
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)

same as before

Comment on lines 4762 to 4765
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Contributor

Suggested change:

Before:
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)

After:
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)

Contributor Author

Thank you! Done!

@elismasilva
Contributor Author

thanks for making the code simpler! I left a few more questions

Thanks for the reviews guys, I will test the proposed scenarios and come back with comments and changes later.

@elismasilva
Contributor Author

Well, I made all the changes and everything is working. I will wait for your replies before committing my changes.

@elismasilva
Contributor Author

Well, I made all the changes and everything is working. I will wait for your replies before committing my changes.

@yiyixuxu @fabiorigano changes committed!

Collaborator

@yiyixuxu yiyixuxu left a comment

thanks

@yiyixuxu
Collaborator

yiyixuxu commented Nov 8, 2024

can you run make style?

@elismasilva
Contributor Author

can you run make style?

Can I run it only for these 3 files? Because it changes a huge number of files.

@yiyixuxu
Collaborator

yiyixuxu commented Nov 8, 2024

it should not change any files that this PR does not touch if you set up our dev environment correctly! https://huggingface.co/docs/diffusers/en/conceptual/contribution#how-to-open-a-pr

happy to help make style too if you need it

@elismasilva
Contributor Author

it should not change any files that this PR does not touch if you set up our dev environment correctly! https://huggingface.co/docs/diffusers/en/conceptual/contribution#how-to-open-a-pr

happy to help make style too if you need it

Yep, I followed this guide but it's changing about 374 files. I am running on Windows; I will test it on WSL-2, it could be an environment problem.

@yiyixuxu
Collaborator

yiyixuxu commented Nov 8, 2024

ohh actually I cannot push to the PR so I cannot help; maybe try to only add the changes from these 3 files to see if CI passes

cc @asomoza for tips with windows

@yiyixuxu
Collaborator

yiyixuxu commented Nov 8, 2024

you can cherry-pick this commit here if you want c2d1531

but it is better if you are able to figure out what's wrong with your environment so that in the future we will be able to merge your PRs!

@elismasilva
Contributor Author

you can cherry-pick this commit here if you want c2d1531

but it is better if you are able to figure out what's wrong with your environment so that in the future we will be able to merge your PRs!

On my WSL it is working; I will push the changes again.

@elismasilva
Contributor Author

@yiyixuxu I've implemented and tested the missing part: when xFormers is disabled, the processors need to revert to Attention 2.0. It was in my last commit. I think we are done now.
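
A quick sanity check for that revert (a small sketch, run against the test code above):

pipe.disable_xformers_memory_efficient_attention()
# The IP adapter layers should be back on IPAdapterAttnProcessor2_0 (or
# IPAdapterAttnProcessor on older torch) and the remaining layers on AttnProcessor2_0.
print({proc.__class__.__name__ for proc in pipe.unet.attn_processors.values()})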

@elismasilva elismasilva requested a review from yiyixuxu November 8, 2024 22:52
Collaborator

@yiyixuxu yiyixuxu left a comment

thanks!

@yiyixuxu yiyixuxu merged commit dac623b into huggingface:main Nov 9, 2024
15 checks passed
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Feature IP Adapter Xformers Attention Processor: this fix error loading incorrect attention processor when setting Xformers attn after load ip adapter scale, issues: #8863 #8872
Successfully merging this pull request may close these issues:

AttributeError: 'tuple' object has no attribute 'shape'