
Commit 9424087

piercus authored and tolgacangoz committed
Flux: pass joint_attention_kwargs when using gradient_checkpointing (huggingface#11814)
1 parent a8f1423 commit 9424087

File tree

1 file changed: +2 additions, 0 deletions

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 0 deletions
@@ -490,6 +490,7 @@ def forward(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
@@ -521,6 +522,7 @@ def forward(
                     encoder_hidden_states,
                     temb,
                     image_rotary_emb,
+                    joint_attention_kwargs,
                 )

             else:
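The sketch below illustrates the idea behind the fix: when gradient checkpointing is enabled, the block is invoked through torch.utils.checkpoint, so any extra argument such as joint_attention_kwargs must be forwarded explicitly or the block silently falls back to its default. This is a minimal standalone example, not the diffusers source; DummyBlock and its use of the "scale" key are illustrative assumptions.

# Minimal sketch, assuming a block whose forward accepts joint_attention_kwargs.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class DummyBlock(nn.Module):
    """Stand-in for a Flux transformer block (hypothetical, for illustration only)."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb, joint_attention_kwargs=None):
        # If the kwargs are not forwarded, this falls back to scale = 1.0.
        scale = 1.0 if joint_attention_kwargs is None else joint_attention_kwargs.get("scale", 1.0)
        hidden_states = self.proj(hidden_states) * scale + temb
        return encoder_hidden_states, hidden_states


block = DummyBlock()
hidden_states = torch.randn(2, 8, requires_grad=True)
encoder_hidden_states = torch.randn(2, 8)
temb = torch.randn(2, 8)
image_rotary_emb = None
joint_attention_kwargs = {"scale": 0.5}

# Passing joint_attention_kwargs through the checkpointed call keeps the
# gradient-checkpointing path consistent with the regular forward path.
encoder_hidden_states, hidden_states = checkpoint(
    block,
    hidden_states,
    encoder_hidden_states,
    temb,
    image_rotary_emb,
    joint_attention_kwargs,
    use_reentrant=False,
)
hidden_states.sum().backward()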
