
Commit

style
yiyixuxu committed Nov 8, 2024
1 parent 7741fb0 commit c2d1531
Showing 3 changed files with 43 additions and 17 deletions.
9 changes: 6 additions & 3 deletions src/diffusers/loaders/ip_adapter.py
@@ -33,7 +33,6 @@


if is_transformers_available():
-
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from ..models.attention_processor import (
@@ -283,7 +282,9 @@ def set_ip_adapter_scale(self, scale):
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

for attn_name, attn_processor in unet.attn_processors.items():
-if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)):
+if isinstance(
+    attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -341,7 +342,9 @@ def unload_ip_adapter(self):
)
attn_procs[name] = (
attn_processor_class
-if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor))
+if isinstance(
+    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
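
For context, the widened isinstance checks above mean that set_ip_adapter_scale and unload_ip_adapter now also cover attention layers running the new IPAdapterXFormersAttnProcessor. A minimal usage sketch, not part of this commit; the checkpoint and adapter weight names below are illustrative examples of an IP-Adapter-compatible setup, and running it assumes xformers and a CUDA device are available:

import torch
from diffusers import AutoPipelineForText2Image

# Illustrative model/adapter choices; any pipeline with an IP-Adapter loaded behaves the same way.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# Swap the attention processors to their xFormers variants
# (the IP-Adapter layers become IPAdapterXFormersAttnProcessor).
pipeline.enable_xformers_memory_efficient_attention()

pipeline.set_ip_adapter_scale(0.6)  # the scale now reaches the xFormers IP-Adapter processors too
# ... run inference ...
pipeline.unload_ip_adapter()        # and unloading restores plain attention processors for those layers
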
10 changes: 6 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -806,12 +806,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
-if ('XFormers' in str(self.attn_processors[name].__class__)):
-    attn_processor_class = (IPAdapterXFormersAttnProcessor)
+if "XFormers" in str(self.attn_processors[name].__class__):
+    attn_processor_class = IPAdapterXFormersAttnProcessor
else:
attn_processor_class = (
-    IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
-)
+    IPAdapterAttnProcessor2_0
+    if hasattr(F, "scaled_dot_product_attention")
+    else IPAdapterAttnProcessor
+)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
41 changes: 31 additions & 10 deletions src/diffusers/models/attention_processor.py
@@ -372,14 +372,18 @@ def set_use_memory_efficient_attention_xformers(
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
elif is_ip_adapter:
-processor = IPAdapterXFormersAttnProcessor(hidden_size=self.processor.hidden_size,
+processor = IPAdapterXFormersAttnProcessor(
+    hidden_size=self.processor.hidden_size,
cross_attention_dim=self.processor.cross_attention_dim,
scale=self.processor.scale,
num_tokens=self.processor.num_tokens,
-    attention_op=attention_op)
+    attention_op=attention_op,
+)
processor.load_state_dict(self.processor.state_dict())
if hasattr(self.processor, "to_k_ip"):
-    processor.to(device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype)
+    processor.to(
+        device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
+    )
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
@@ -4569,10 +4573,19 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
-[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
-as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
+[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+operator.
"""
-def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, attention_op: Optional[Callable] = None):
+
+def __init__(
+    self,
+    hidden_size,
+    cross_attention_dim=None,
+    num_tokens=(4,),
+    scale=1.0,
+    attention_op: Optional[Callable] = None,
+):
super().__init__()

self.hidden_size = hidden_size
@@ -4665,7 +4678,9 @@ def __call__(
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()

-hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)
+hidden_states = xformers.ops.memory_efficient_attention(
+    query, key, value, attn_bias=attention_mask, op=self.attention_op
+)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -4681,7 +4696,9 @@ def __call__(
f"({len(ip_hidden_states)})"
)
else:
-for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+for index, (mask, scale, ip_state) in enumerate(
+    zip(ip_adapter_masks, self.scale, ip_hidden_states)
+):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
@@ -4727,7 +4744,9 @@ def __call__(
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()

-_current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)
+_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+    query, ip_key, ip_value, op=self.attention_op
+)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

@@ -4747,7 +4766,9 @@ def __call__(
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()

-current_ip_hidden_states = xformers.ops.memory_efficient_attention(query, ip_key, ip_value, op=self.attention_op)
+current_ip_hidden_states = xformers.ops.memory_efficient_attention(
+    query, ip_key, ip_value, op=self.attention_op
+)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
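
For reference, the processor that set_use_memory_efficient_attention_xformers configures above can also be built by hand with the new __init__ signature shown in this file. A minimal sketch, assuming SD1.5-style dimensions; the numbers are illustrative rather than required, and calling the processor (as opposed to constructing it) needs xformers installed and a CUDA device:

import torch

from diffusers.models.attention_processor import IPAdapterXFormersAttnProcessor

processor = IPAdapterXFormersAttnProcessor(
    hidden_size=320,          # hidden size of the attention block this processor is attached to (illustrative)
    cross_attention_dim=768,  # text-encoder embedding dimension for an SD1.5-style UNet (illustrative)
    num_tokens=(4,),          # one IP-Adapter with 4 image-prompt tokens
    scale=1.0,                # weight of the image prompt
    attention_op=None,        # let xFormers pick the best operator, as the docstring recommends
)

# The processor is an nn.Module holding the to_k_ip/to_v_ip projections, so it can be moved like any
# other module; this mirrors the hasattr(self.processor, "to_k_ip") branch above after load_state_dict.
processor.to(device="cuda", dtype=torch.float16)
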

