Feature IP Adapter Xformers Attention Processor #9881

Merged (10 commits, Nov 9, 2024)
src/diffusers/loaders/ip_adapter.py (11 changes: 5 additions & 6 deletions)

@@ -33,16 +33,15 @@


if is_transformers_available():
-   from transformers import (
-       CLIPImageProcessor,
-       CLIPVisionModelWithProjection,
-   )
+   from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

    from ..models.attention_processor import (
        AttnProcessor,
        AttnProcessor2_0,
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
+       IPAdapterXFormersAttnProcessor,
    )

logger = logging.get_logger(__name__)
@@ -284,7 +283,7 @@ def set_ip_adapter_scale(self, scale):
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

        for attn_name, attn_processor in unet.attn_processors.items():
-           if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+           if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)):
                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
                        f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -342,7 +341,7 @@ def unload_ip_adapter(self):
            )
            attn_procs[name] = (
                attn_processor_class
-               if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
+               if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor))
                else value.__class__()
            )
        self.unet.set_attn_processor(attn_procs)
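For context, a usage sketch of what these loader changes enable (not code from the PR; checkpoint names are illustrative and xformers must be installed): once the IP-Adapter processors are recognized by the loader, enabling xformers memory-efficient attention and adjusting the adapter scale both work with the xformers variant.

    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

    # Swaps the loaded IP-Adapter attention processors for IPAdapterXFormersAttnProcessor.
    pipe.enable_xformers_memory_efficient_attention()

    # set_ip_adapter_scale now also accepts the xformers processor (the isinstance check above).
    pipe.set_ip_adapter_scale(0.6)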
src/diffusers/loaders/unet.py (7 changes: 5 additions & 2 deletions)

@@ -765,6 +765,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
        from ..models.attention_processor import (
            IPAdapterAttnProcessor,
            IPAdapterAttnProcessor2_0,
+           IPAdapterXFormersAttnProcessor,
        )

        if low_cpu_mem_usage:
@@ -804,9 +805,11 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
            if cross_attention_dim is None or "motion_modules" in name:
                attn_processor_class = self.attn_processors[name].__class__
                attn_procs[name] = attn_processor_class()

            else:
+               if "XFormers" in str(self.attn_processors[name].__class__):
+                   attn_processor_class = IPAdapterXFormersAttnProcessor
+               else:
                    attn_processor_class = (
                        IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
                    )
                num_image_text_embeds = []
src/diffusers/models/attention_processor.py (243 changes: 242 additions & 1 deletion)

@@ -369,7 +369,20 @@ def set_use_memory_efficient_attention_xformers(
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
-               processor = XFormersAttnProcessor(attention_op=attention_op)
+               processor = self.processor
+               if isinstance(self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
Collaborator:
Let's add an is_ip_adapter flag, similar to is_custom_diffusion etc.:

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
)

Contributor Author (@elismasilva), Nov 8, 2024:
Yes, that is perfectly possible, but it will have to look like the snippet below, so that modules that have already been switched to the xFormers attention class are not replaced again with XFormersAttnProcessor in the final else branch during the method's recursion.

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor,
    (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)

Collaborator:
Yes, the code you show here is OK! We just want to keep a consistent style, that's all :)
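A minimal, self-contained sketch of the check both snippets describe (the helper name is_ip_adapter_processor is illustrative, and the import assumes this PR's IPAdapterXFormersAttnProcessor is available in diffusers.models.attention_processor):

    from diffusers.models.attention_processor import (
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
        IPAdapterXFormersAttnProcessor,
    )


    def is_ip_adapter_processor(processor) -> bool:
        # Covers the torch, SDPA, and xformers IP-Adapter variants, so layers that were already
        # switched to IPAdapterXFormersAttnProcessor are not downgraded to XFormersAttnProcessor
        # on a repeated call of set_use_memory_efficient_attention_xformers.
        return isinstance(
            processor,
            (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
        )


    print(is_ip_adapter_processor(IPAdapterAttnProcessor2_0(hidden_size=320, cross_attention_dim=768)))  # True

The PR diff then continues below with the construction of the xformers processor.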

                    processor = IPAdapterXFormersAttnProcessor(
                        hidden_size=self.processor.hidden_size,
                        cross_attention_dim=self.processor.cross_attention_dim,
                        scale=self.processor.scale,
                        num_tokens=self.processor.num_tokens,
                        attention_op=attention_op,
                    )
                    processor.load_state_dict(self.processor.state_dict())
Collaborator:
why do we need to load_state_dict again here?

Contributor Author (@elismasilva), Nov 8, 2024:
Well, I couldn't identify exactly why, but if I don't reload the state_dict here after assigning the new class, the final result is not applied to the image. I don't know whether it's because the call to pipe.enable_xformers_memory_efficient_attention() happened after the IP adapter weights had already been loaded, so it's as if the model were not being used. I saw that you do some manipulations while loading the IP adapter weights, but I don't think it makes sense to replicate that logic here, and I don't know if that's the reason. See one final image without the state dict and another with it. I noticed that something similar is done in custom diffusion, so for practicality I did the same. If you have a better solution I would like to try it.

Without load_state_dict:
(image: result_1_diff)

With load_state_dict:
(image: result_1_diff)

Collaborator:
Oh yeah, it comes with weights:

    self.to_k_ip = nn.ModuleList(

(I had forgotten about that, sorry! lol)
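A short, self-contained sketch of why the copy matters (shapes are arbitrary; the imports assume the new class ships in diffusers.models.attention_processor): the IP-Adapter processors are nn.Modules that own the to_k_ip/to_v_ip projections, so a freshly constructed xformers processor starts with random weights until the old processor's state dict is loaded into it.

    import torch
    from diffusers.models.attention_processor import IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor

    old_proc = IPAdapterAttnProcessor2_0(hidden_size=1280, cross_attention_dim=2048)    # stands in for the loaded IP-Adapter weights
    new_proc = IPAdapterXFormersAttnProcessor(hidden_size=1280, cross_attention_dim=2048)  # starts with random to_k_ip / to_v_ip

    new_proc.load_state_dict(old_proc.state_dict())  # copy the image-projection weights across
    assert torch.equal(new_proc.to_k_ip[0].weight, old_proc.to_k_ip[0].weight)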

                    if len(self.processor._modules) > 0:
Collaborator:
what does this section of code do?

Contributor Author:
This passes the initialization parameters that were already defined while loading the IP adapter model on to the new xformers attention class. After that I reload the already-loaded state_dict into the new object, as explained in the previous question. Then I just make sure the weights end up on the same device and dtype as before, because reloading the state_dict places them on "cpu" with dtype "float32".
I changed this initial if statement in line 380 to check for existing modules, just to avoid unexpected errors:

if hasattr(self.processor, "_modules") and len(self.processor._modules) > 0:

Collaborator:
OK, I think we can simplify the code here a little bit: we are inside an if statement, so we already know the processor will be either IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, or IPAdapterXFormersAttnProcessor. In all three cases it will have a to_k_ip layer and a to_v_ip layer, so maybe we can just get the device info from self.to_k_ip[0].device.

Contributor Author:
Oh yes, I had done it in an agnostic way, because I wasn't sure that there would always be these modules in all the models that might arrive there. I'll change it to your solution.

                        module_list = list(self.processor._modules)
                        if len(module_list) > 0:
                            processor.to(
                                device=self.processor._modules[module_list[0]][0].weight.device,
                                dtype=self.processor._modules[module_list[0]][0].weight.dtype,
                            )
                elif isinstance(self.processor, (AttnProcessor, AttnProcessor2_0)):
                    processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_custom_diffusion:
                attn_processor_class = (
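A sketch of the simplification agreed on above (the helper name _align_device_and_dtype is illustrative, not part of the PR): since every IP-Adapter processor variant owns a to_k_ip ModuleList, its first projection weight is a reliable source of device and dtype after load_state_dict() has left the copied weights on CPU in float32.

    import torch


    def _align_device_and_dtype(new_processor: torch.nn.Module, old_processor: torch.nn.Module) -> None:
        # Read device/dtype straight from the first to_k_ip projection of the old processor
        # instead of walking _modules generically.
        reference = old_processor.to_k_ip[0].weight
        new_processor.to(device=reference.device, dtype=reference.dtype)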
@@ -4541,7 +4554,235 @@ def __call__(

        return hidden_states

class IPAdapterXFormersAttnProcessor(torch.nn.Module):
yiyixuxu marked this conversation as resolved.
r"""
Attention processor for IP-Adapter using xFormers.

Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
"""
    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        num_tokens=(4,),
        scale=1.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
Contributor:
This check can be removed because this class uses xformers.ops.memory_efficient_attention instead of torch.nn.functional.scaled_dot_product_attention.

            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
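        # Sketch only, not part of this PR's diff: as the reviewer notes above, this processor relies
        # on xformers rather than torch SDPA, so an xformers availability guard would be the more
        # fitting check, e.g. with diffusers.utils.import_utils.is_xformers_available:
        #
        #     if not is_xformers_available():
        #         raise ImportError(f"{self.__class__.__name__} requires the xformers library.")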

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        return hidden_states

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
Contributor:

Suggested change:

-       query = attn.head_to_batch_dim(query)
-       key = attn.head_to_batch_dim(key)
-       value = attn.head_to_batch_dim(value)
-       hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+       query = attn.head_to_batch_dim(query).contiguous()
+       key = attn.head_to_batch_dim(key).contiguous()
+       value = attn.head_to_batch_dim(value).contiguous()
+       hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)

Just another observation: if we make the tensors contiguous here, we can avoid multiple calls to query.contiguous() later in the code (every time self._memory_efficient_attention_xformers is called, query is reused). This way, we can call xformers.ops.memory_efficient_attention directly.
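A tiny standalone illustration of the point above (not part of the PR; tensor shapes are arbitrary, and it requires xformers plus a CUDA device): once query, key, and value are made contiguous after head_to_batch_dim, the same query tensor can be reused for every xformers.ops.memory_efficient_attention call without further .contiguous() copies.

    import torch
    import xformers.ops

    device, dtype = "cuda", torch.float16
    batch_heads, q_len, head_dim = 2 * 8, 4096, 40

    query = torch.randn(batch_heads, q_len, head_dim, device=device, dtype=dtype).contiguous()
    key = torch.randn(batch_heads, 77, head_dim, device=device, dtype=dtype).contiguous()
    value = torch.randn(batch_heads, 77, head_dim, device=device, dtype=dtype).contiguous()

    text_out = xformers.ops.memory_efficient_attention(query, key, value)  # text cross-attention

    ip_key = torch.randn(batch_heads, 4, head_dim, device=device, dtype=dtype).contiguous()
    ip_value = torch.randn(batch_heads, 4, head_dim, device=device, dtype=dtype).contiguous()
    ip_out = xformers.ops.memory_efficient_attention(query, ip_key, ip_value)  # reuses the same contiguous query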


        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states:
            if ip_adapter_masks is not None:
                if not isinstance(ip_adapter_masks, List):
                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                    raise ValueError(
                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                        f"({len(ip_hidden_states)})"
                    )
                else:
                    for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                        if mask is None:
                            continue
                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                            raise ValueError(
                                "Each element of the ip_adapter_masks array should be a tensor with shape "
                                "[1, num_images_for_ip_adapter, height, width]."
                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                            )
                        if mask.shape[1] != ip_state.shape[1]:
                            raise ValueError(
                                f"Number of masks ({mask.shape[1]}) does not match "
                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
                            )
                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                            raise ValueError(
                                f"Number of masks ({mask.shape[1]}) does not match "
                                f"number of scales ({len(scale)}) at index {index}"
                            )
            else:
                ip_adapter_masks = [None] * len(self.scale)

            # for ip-adapter
            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
            ):
                skip = False
                if isinstance(scale, list):
                    if all(s == 0 for s in scale):
                        skip = True
                elif scale == 0:
                    skip = True
                if not skip:
                    if mask is not None:
                        mask = mask.to(torch.float16)
                        if not isinstance(scale, list):
                            scale = [scale] * mask.shape[1]

                        current_num_images = mask.shape[1]
                        for i in range(current_num_images):
                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                            ip_key = attn.head_to_batch_dim(ip_key)
                            ip_value = attn.head_to_batch_dim(ip_value)

                            _current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Contributor:

Suggested change:

-                           ip_key = attn.head_to_batch_dim(ip_key)
-                           ip_value = attn.head_to_batch_dim(ip_value)
-                           _current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+                           ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                           ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+                           _current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)

Same as before.


                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
                            _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                            mask_downsample = IPAdapterMaskProcessor.downsample(
                                mask[:, i, :, :],
                                batch_size,
                                _current_ip_hidden_states.shape[1],
                                _current_ip_hidden_states.shape[2],
                            )

                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                    else:
                        ip_key = to_k_ip(current_ip_hidden_states)
                        ip_value = to_v_ip(current_ip_hidden_states)

                        ip_key = attn.head_to_batch_dim(ip_key)
                        ip_value = attn.head_to_batch_dim(ip_value)

                        current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Contributor:

Suggested change:

-                       ip_key = attn.head_to_batch_dim(ip_key)
-                       ip_value = attn.head_to_batch_dim(ip_value)
-                       current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+                       ip_key = attn.head_to_batch_dim(ip_key).contiguous()
+                       ip_value = attn.head_to_batch_dim(ip_value).contiguous()
+                       current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)

Contributor Author:

Thank you! Done!


                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                        hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
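For reference, a construction sketch of the processor added in this PR (the argument values are illustrative; in practice enable_xformers_memory_efficient_attention builds the processor from the already loaded IP-Adapter weights rather than by hand):

    from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

    processor = IPAdapterXFormersAttnProcessor(
        hidden_size=1280,          # hidden size of the attention layer
        cross_attention_dim=2048,  # channels of encoder_hidden_states
        num_tokens=(4,),           # context length of the image features, one entry per IP adapter
        scale=[1.0],               # weight of the image prompt, one entry per IP adapter
        attention_op=None,         # let xFormers pick the best operator
    )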