
Commit 8eb73c8

Matrix53, hlky, and yiyixuxu authored
Support passing kwargs to the SD3 custom attention processor (#9818)
* Support pass kwargs to sd3 custom attention processor

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 88b015d commit 8eb73c8

File tree

2 files changed: 15 additions, 4 deletions

src/diffusers/models/attention.py

Lines changed: 10 additions & 3 deletions
@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,15 +211,17 @@ def forward(
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
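Because the block's forward now unpacks joint_attention_kwargs into self.attn(...) and self.attn2(...), a custom attention processor can receive extra keyword arguments per call. Below is a minimal sketch, not taken from the commit, assuming a hypothetical attn_scale keyword and a processor that subclasses the stock JointAttnProcessor2_0:

# Hedged sketch (not part of this commit): a custom processor whose __call__
# accepts an extra keyword forwarded through joint_attention_kwargs.
# The class name and the `attn_scale` keyword are hypothetical.
from diffusers.models.attention_processor import JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, attn_scale: float = 1.0, **kwargs):
        # Run the stock joint attention, then rescale the sample stream.
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
        if isinstance(out, tuple):  # (hidden_states, encoder_hidden_states)
            return out[0] * attn_scale, out[1]
        return out * attn_scale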
src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 5 additions & 1 deletion
@@ -411,11 +411,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
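
With the transformer forwarding joint_attention_kwargs into every JointTransformerBlock, the extra keywords travel from the StableDiffusion3Pipeline call down to the processor. A rough usage sketch reusing the hypothetical ScaledJointAttnProcessor above; the model id, dtype, and attn_scale value are illustrative only:

# Hedged usage sketch: kwargs supplied as joint_attention_kwargs at call time
# are now forwarded into the custom processor set on the transformer.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

image = pipe(
    "a photo of a cat",
    joint_attention_kwargs={"attn_scale": 0.9},  # reaches __call__ above
).images[0]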
