
Commit f2ae8a1

Support passing kwargs to the SD3 custom attention processor
1 parent 9a92b81 commit f2ae8a1

File tree: 2 files changed, +95 -25 lines
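
In short, the commit threads an optional `joint_attention_kwargs` dict from `SD3Transformer2DModel.forward` through each `JointTransformerBlock` and into the attention call, so a custom attention processor registered on the SD3 transformer can receive extra keyword arguments per call. Below is a minimal sketch of such a processor; it is not part of the commit, the `attn_gain` keyword is a hypothetical example argument, and the class simply wraps the stock `JointAttnProcessor2_0`.

# Hedged sketch (not part of the commit): a custom joint attention processor
# that accepts an extra keyword argument delivered via `joint_attention_kwargs`.
# `attn_gain` is a hypothetical name used only to illustrate the new path.
import torch

from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    """Runs the stock SD3 joint attention, then scales the image-stream output."""

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        attn_gain: float = 1.0,  # hypothetical kwarg forwarded through joint_attention_kwargs
        **kwargs,
    ):
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
        # Joint attention returns (hidden_states, encoder_hidden_states); the
        # dual-attention branch (`attn2`) returns a single tensor.
        if isinstance(out, tuple):
            hidden_states, encoder_hidden_states = out
            return hidden_states * attn_gain, encoder_hidden_states
        return out * attn_gain

Declaring the extra argument explicitly in `__call__` keeps the example robust if `Attention.forward` filters out keyword arguments it does not recognize.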

src/diffusers/models/attention.py

Lines changed: 57 additions & 18 deletions

@@ -22,7 +22,13 @@
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .normalization import (
+    AdaLayerNorm,
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    RMSNorm,
+    SD35AdaLayerNormZeroX,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -122,7 +128,12 @@ def __init__(
 
         if context_norm_type == "ada_norm_continous":
             self.norm1_context = AdaLayerNormContinuous(
-                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+                dim,
+                dim,
+                elementwise_affine=False,
+                eps=1e-6,
+                bias=True,
+                norm_type="layer_norm",
             )
         elif context_norm_type == "ada_norm_zero":
             self.norm1_context = AdaLayerNormZero(dim)
@@ -188,33 +199,51 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = None,
     ):
         if self.use_dual_attention:
-            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
-                hidden_states, emb=temb
-            )
+            (
+                norm_hidden_states,
+                gate_msa,
+                shift_mlp,
+                scale_mlp,
+                gate_mlp,
+                norm_hidden_states2,
+                gate_msa2,
+            ) = self.norm1(hidden_states, emb=temb)
         else:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         if self.context_pre_only:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
         else:
-            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
-                encoder_hidden_states, emb=temb
-            )
+            (
+                norm_encoder_hidden_states,
+                c_gate_msa,
+                c_shift_mlp,
+                c_scale_mlp,
+                c_gate_mlp,
+            ) = self.norm1_context(encoder_hidden_states, emb=temb)
+
+        joint_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
        )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
@@ -241,7 +270,10 @@ def forward(
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
             context_ff_output = _chunked_feed_forward(
-                self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                self.ff_context,
+                norm_encoder_hidden_states,
+                self._chunk_dim,
+                self._chunk_size,
             )
         else:
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
@@ -402,7 +434,7 @@ def __init__(
 
         self.attn2 = Attention(
             query_dim=dim,
-            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+            cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
@@ -506,7 +538,7 @@ def forward(
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
@@ -979,7 +1011,7 @@ def __init__(
 
         self.attn2 = Attention(
             query_dim=dim,
-            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+            cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
@@ -1045,7 +1077,10 @@ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid")
         return weights
 
     def set_free_noise_properties(
-        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
+        self,
+        context_length: int,
+        context_stride: int,
+        weighting_scheme: str = "pyramid",
     ) -> None:
         self.context_length = context_length
         self.context_stride = context_stride
@@ -1112,7 +1147,7 @@ def forward(
 
         attn_output = self.attn1(
             norm_hidden_states,
-            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
             attention_mask=attention_mask,
             **cross_attention_kwargs,
         )
@@ -1158,7 +1193,11 @@ def forward(
         # looked into this deeply because other memory optimizations led to more pronounced reductions.
         hidden_states = torch.cat(
             [
-                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
+                torch.where(
+                    num_times_split > 0,
+                    accumulated_split / num_times_split,
+                    accumulated_split,
+                )
                 for accumulated_split, num_times_split in zip(
                     accumulated_values.split(self.context_length, dim=1),
                     num_times_accumulated.split(self.context_length, dim=1),
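
With the block-level change above, `JointTransformerBlock.forward` copies the incoming dict and expands it into `self.attn(...)` and, in the dual-attention path, `self.attn2(...)`. A hedged usage sketch follows, assuming the `ScaledJointAttnProcessor` from the earlier sketch, the public `stabilityai/stable-diffusion-3-medium-diffusers` checkpoint, and the pipeline's `joint_attention_kwargs` argument to forward the dict; `attn_gain` remains a hypothetical name.

# Hedged usage sketch: register the custom processor on the SD3 transformer and
# pass the extra kwarg per call via `joint_attention_kwargs`.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Install the processor on every attention module of the transformer.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

# The dict travels: pipeline -> SD3Transformer2DModel.forward ->
# each JointTransformerBlock.forward -> the processor's __call__.
image = pipe(
    "a photo of an astronaut riding a horse on the moon",
    joint_attention_kwargs={"attn_gain": 1.05},
).images[0]
image.save("astronaut.png")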

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 38 additions & 7 deletions

@@ -21,10 +21,20 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import JointTransformerBlock
-from ...models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FusedJointAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_version,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 
@@ -88,7 +98,8 @@ def __init__(
             pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
         )
         self.time_text_embed = CombinedTimestepTextProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim,
+            pooled_projection_dim=self.config.pooled_projection_dim,
         )
         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
 
@@ -166,7 +177,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]:
         # set recursively
         processors = {}
 
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
             if hasattr(module, "get_processor"):
                 processors[f"{name}.processor"] = module.get_processor()
 
@@ -334,12 +349,16 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
 
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
@@ -356,11 +375,23 @@ def custom_forward(*inputs):
         width = width // patch_size
 
         hidden_states = hidden_states.reshape(
-            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
+            shape=(
+                hidden_states.shape[0],
+                height,
+                width,
+                patch_size,
+                patch_size,
+                self.out_channels,
+            )
         )
         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
         output = hidden_states.reshape(
-            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+            shape=(
+                hidden_states.shape[0],
+                self.out_channels,
+                height * patch_size,
+                width * patch_size,
+            )
        )
 
         if USE_PEFT_BACKEND:
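
On the transformer side, the dict reaches each block by keyword in the eager path and positionally in the gradient-checkpointing path, because the `custom_forward(*inputs)` wrapper shown in the hunk headers only relays positional inputs; adding `joint_attention_kwargs` as the fourth positional parameter of `JointTransformerBlock.forward` is what makes that line up. A minimal toy sketch of the pattern, assuming nothing beyond PyTorch itself (`ToyBlock` and `attn_gain` are illustrative names, not diffusers APIs):

# Toy illustration of why the checkpointed branch appends joint_attention_kwargs
# as a positional argument: torch.utils.checkpoint re-invokes custom_forward with
# positional inputs only, so the dict must occupy the block's fourth parameter.
import torch


class ToyBlock(torch.nn.Module):
    # Mirrors the parameter order and return order of JointTransformerBlock.forward.
    def forward(self, hidden_states, encoder_hidden_states, temb, joint_attention_kwargs=None):
        gain = (joint_attention_kwargs or {}).get("attn_gain", 1.0)
        return encoder_hidden_states, hidden_states * gain


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = ToyBlock()
hidden = torch.randn(1, 4, 8, requires_grad=True)
context = torch.randn(1, 2, 8, requires_grad=True)
temb = torch.randn(1, 8)

context_out, hidden_out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block),
    hidden,
    context,
    temb,
    {"attn_gain": 2.0},  # rides along as the fourth positional input
    use_reentrant=False,
)
hidden_out.sum().backward()
print(hidden.grad.abs().mean())  # gradients still flow through the checkpointed call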
