Feature IP Adapter Xformers Attention Processor #9881
Changes from 4 commits
@@ -369,7 +369,20 @@ def set_use_memory_efficient_attention_xformers(
                )
                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
                processor = self.processor
                if isinstance(self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                    processor = IPAdapterXFormersAttnProcessor(
                        hidden_size=self.processor.hidden_size,
                        cross_attention_dim=self.processor.cross_attention_dim,
                        scale=self.processor.scale,
                        num_tokens=self.processor.num_tokens,
                        attention_op=attention_op,
                    )
                    processor.load_state_dict(self.processor.state_dict())
Review thread on processor.load_state_dict(self.processor.state_dict()):

Reviewer: why do we need to …

Author: Well, I couldn't identify why, but if I don't reload the state_dict here after assigning the new class, the final result is not applied to the image. I don't know if it's because the call to "pipe.enable_xformers_memory_efficient_attention()" happens after the IP adapter weights have already been loaded, so it's as if the model were not being used. I saw that some manipulation is done while the IP adapter weights are loaded, but I don't think it makes sense to replicate that logic here, and I'm not sure that's the reason. See one final image without the state dict reload and another with it. I noticed that custom diffusion does something similar, so for practicality I did the same. If you have a better solution, I would like to try it.

Reviewer: oh yeah, it comes with weights (I had forgotten about that, sorry!)
                    if len(self.processor._modules) > 0:
Review thread on this block:

Reviewer: what does this section of code do?

Author: This passes the initialization parameters that were already defined while loading the IP adapter model on to the new xFormers attention class. After that, I reload the state_dict that was already loaded into the new object, as explained in the previous thread. Then I make sure the weights keep the device and dtype they had before, because after reloading the state_dict they end up on "cpu" with dtype "float32".

Quoted line: if hasattr(self.processor, "_modules") and len(self.processor._modules) > 0:

Reviewer: ok, I think we can simplify the code here a little bit because we are inside an …

Author: Oh yes, I had done it in an agnostic way, because I wasn't sure that these modules would always exist in all the models that might arrive here. I'll change it to your solution.
                        module_list = list(self.processor._modules)
                        if len(module_list) > 0:
                            processor.to(
                                device=self.processor._modules[module_list[0]][0].weight.device,
                                dtype=self.processor._modules[module_list[0]][0].weight.dtype,
                            )
                elif isinstance(self.processor, (AttnProcessor, AttnProcessor2_0)):
                    processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_custom_diffusion:
                attn_processor_class = (
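To ground the device/dtype point raised in the two threads above, here is a tiny, self-contained torch-only demo (an illustration added here, not part of the diff): load_state_dict copies values into an already-constructed module, so a processor that was just built on CPU in float32 keeps that dtype until it is explicitly moved.

```python
import torch
import torch.nn as nn

# Stand-in for the already-loaded IP-Adapter projection layer (e.g. half precision).
old = nn.Linear(4, 4, bias=False).to(dtype=torch.float16)
# Stand-in for a freshly constructed processor layer: float32, random weights.
new = nn.Linear(4, 4, bias=False)

new.load_state_dict(old.state_dict())
print(new.weight.dtype)                                # torch.float32 -> values copied, dtype kept
print(torch.allclose(new.weight, old.weight.float()))  # True -> the weights did transfer

# The explicit move the diff performs via processor.to(device=..., dtype=...).
new.to(dtype=old.weight.dtype)
print(new.weight.dtype)                                # torch.float16
```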
@@ -4541,7 +4554,235 @@ def __call__(

        return hidden_states

class IPAdapterXFormersAttnProcessor(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter using xFormers.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
            The context length of the image features.
        scale (`float` or `List[float]`, defaults to 1.0):
            the weight scale of image prompt.
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase)
            to use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the
            best operator.
    """

    def __init__(
        self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, attention_op: Optional[Callable] = None
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
Review thread on this check:

Reviewer: this check can be removed because this class uses …
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
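Picking up the reviewer's point above, a hedged sketch of an xformers-oriented guard that could stand in for the PyTorch 2.0 check (an assumption about the intent, not code from the PR; it relies on diffusers' is_xformers_available utility):

```python
from diffusers.utils.import_utils import is_xformers_available

# Fail early if xformers is missing, since this processor calls xformers.ops.memory_efficient_attention.
if not is_xformers_available():
    raise ImportError(
        "IPAdapterXFormersAttnProcessor requires the xformers library. Install it with `pip install xformers`."
    )
```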
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]
        self.num_tokens = num_tokens

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )
        self.to_v_ip = nn.ModuleList(
            [nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
        )

    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
        # TODO attention_mask
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()
        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        return hidden_states
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        ip_adapter_masks: Optional[torch.FloatTensor] = None,
    ):
        residual = hidden_states

        # separate ip_hidden_states from encoder_hidden_states
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, tuple):
                encoder_hidden_states, ip_hidden_states = encoder_hidden_states
            else:
                deprecation_message = (
                    "You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
                    " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
                )
                deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
                end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
                encoder_hidden_states, ip_hidden_states = (
                    encoder_hidden_states[:, :end_pos, :],
                    [encoder_hidden_states[:, end_pos:, :]],
                )

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
Review thread (suggested change) on this call:

Reviewer: just another observation: if we make the tensors contiguous here, we can avoid multiple calls to …
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states:
            if ip_adapter_masks is not None:
                if not isinstance(ip_adapter_masks, List):
                    # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
                    ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
                if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                    raise ValueError(
                        f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
                        f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
                        f"({len(ip_hidden_states)})"
                    )
                else:
                    for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
                        if mask is None:
                            continue
                        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                            raise ValueError(
                                "Each element of the ip_adapter_masks array should be a tensor with shape "
                                "[1, num_images_for_ip_adapter, height, width]."
                                " Please use `IPAdapterMaskProcessor` to preprocess your mask"
                            )
                        if mask.shape[1] != ip_state.shape[1]:
                            raise ValueError(
                                f"Number of masks ({mask.shape[1]}) does not match "
                                f"number of ip images ({ip_state.shape[1]}) at index {index}"
                            )
                        if isinstance(scale, list) and not len(scale) == mask.shape[1]:
                            raise ValueError(
                                f"Number of masks ({mask.shape[1]}) does not match "
                                f"number of scales ({len(scale)}) at index {index}"
                            )
            else:
                ip_adapter_masks = [None] * len(self.scale)

            # for ip-adapter
            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
            ):
                skip = False
                if isinstance(scale, list):
                    if all(s == 0 for s in scale):
                        skip = True
                elif scale == 0:
                    skip = True
                if not skip:
                    if mask is not None:
                        mask = mask.to(torch.float16)
                        if not isinstance(scale, list):
                            scale = [scale] * mask.shape[1]

                        current_num_images = mask.shape[1]
                        for i in range(current_num_images):
                            ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
                            ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

                            ip_key = attn.head_to_batch_dim(ip_key)
                            ip_value = attn.head_to_batch_dim(ip_value)

                            _current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Review thread (suggested change) on this call:

Reviewer: same as before.
                            _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
                            _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

                            mask_downsample = IPAdapterMaskProcessor.downsample(
                                mask[:, i, :, :],
                                batch_size,
                                _current_ip_hidden_states.shape[1],
                                _current_ip_hidden_states.shape[2],
                            )

                            mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
                            hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
                    else:
                        ip_key = to_k_ip(current_ip_hidden_states)
                        ip_value = to_v_ip(current_ip_hidden_states)

                        ip_key = attn.head_to_batch_dim(ip_key)
                        ip_value = attn.head_to_batch_dim(ip_value)

                        current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
Review thread (suggested change) on this call:

Reviewer: …

Author: Thank you! Done!
                        current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
                        current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

                        hidden_states = hidden_states + scale * current_ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class PAGIdentitySelfAttnProcessor2_0:
    r"""
    Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
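For context on how the new processor gets exercised end to end, here is a rough usage sketch added alongside the diff (not from the PR): the checkpoint id, adapter weight name, and image path are placeholders, and it assumes a diffusers build that contains this PR plus an xformers install.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# Placeholder checkpoint / adapter names: substitute whatever you actually use.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale(0.6)

# This call walks the attention modules and, with this PR, swaps
# IPAdapterAttnProcessor(2_0) for IPAdapterXFormersAttnProcessor (first hunk above).
pipe.enable_xformers_memory_efficient_attention()

image = pipe(
    prompt="a cat, best quality",
    ip_adapter_image=load_image("reference.png"),  # placeholder reference image
    num_inference_steps=30,
).images[0]
image.save("out.png")
```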
Review thread on set_use_memory_efficient_attention_xformers:

Reviewer: let's add an is_ip_adapter flag similar to is_custom_diffusion etc.

Author: Yes, that is perfectly possible, but it will have to look like the code below, so that modules that have already been switched to the xFormers attention class are not replaced again by XFormersAttnProcessor in the final else branch during the method's recursion.

Reviewer: yes, the code you show here is ok! We just want to keep a consistent style, that's all :)
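The author's snippet referenced above ("the code below") is not reproduced in this view, so here is a hedged sketch, under hypothetical naming, of how an is_ip_adapter branch could look while guarding against the recursion issue the author mentions: already-converted modules must not fall through to the plain XFormersAttnProcessor branch.

```python
from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
    IPAdapterXFormersAttnProcessor,
    XFormersAttnProcessor,
)


def pick_xformers_processor(attn_module, attention_op=None):
    """Hypothetical helper: choose the xFormers processor for one Attention module."""
    # Treat the already-converted class as "is_ip_adapter" too, so a recursive second pass
    # does not overwrite it with a plain XFormersAttnProcessor.
    is_ip_adapter = isinstance(
        attn_module.processor,
        (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
    )
    if not is_ip_adapter:
        return XFormersAttnProcessor(attention_op=attention_op)

    processor = IPAdapterXFormersAttnProcessor(
        hidden_size=attn_module.processor.hidden_size,
        cross_attention_dim=attn_module.processor.cross_attention_dim,
        num_tokens=attn_module.processor.num_tokens,
        scale=attn_module.processor.scale,
        attention_op=attention_op,
    )
    # Same reload-and-move steps as in the first hunk of the diff.
    processor.load_state_dict(attn_module.processor.state_dict())
    processor.to(
        device=attn_module.processor.to_k_ip[0].weight.device,
        dtype=attn_module.processor.to_k_ip[0].weight.dtype,
    )
    return processor
```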