Skip to content

Commit 83bfc42

Browse files
committed
ZImageTransformer2D: Only build attention mask if seqlens are not equal
1 parent 99daaa8 commit 83bfc42

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -788,9 +788,12 @@ def _prepare_sequence(
788788
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
789789

790790
# Attention mask
791-
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
792-
for i, seq_len in enumerate(item_seqlens):
793-
attn_mask[i, :seq_len] = 1
791+
if all(seq == max_seqlen for seq in item_seqlens):
792+
attn_mask = None
793+
else:
794+
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
795+
for i, seq_len in enumerate(item_seqlens):
796+
attn_mask[i, :seq_len] = 1
794797

795798
# Noise mask
796799
noise_mask_tensor = None
@@ -871,9 +874,12 @@ def _build_unified_sequence(
871874
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
872875

873876
# Attention mask
874-
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
875-
for i, seq_len in enumerate(unified_seqlens):
876-
attn_mask[i, :seq_len] = 1
877+
if all(seq == max_seqlen for seq in unified_seqlens):
878+
attn_mask = None
879+
else:
880+
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
881+
for i, seq_len in enumerate(unified_seqlens):
882+
attn_mask[i, :seq_len] = 1
877883

878884
# Noise mask
879885
noise_mask_tensor = None

0 commit comments

Comments
 (0)