@@ -788,9 +788,12 @@ def _prepare_sequence(
         freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(item_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in item_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(item_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
@@ -871,9 +874,12 @@ def _build_unified_sequence(
         unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
 
         # Attention mask
-        attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
-        for i, seq_len in enumerate(unified_seqlens):
-            attn_mask[i, :seq_len] = 1
+        if all(seq == max_seqlen for seq in unified_seqlens):
+            attn_mask = None
+        else:
+            attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
+            for i, seq_len in enumerate(unified_seqlens):
+                attn_mask[i, :seq_len] = 1
 
         # Noise mask
         noise_mask_tensor = None
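
Both hunks apply the same change: when every sequence in the batch already spans `max_seqlen`, there is no padding to mask, so `attn_mask` is set to `None` and the downstream attention call can skip masking entirely. A minimal standalone sketch of the pattern follows; the `build_padding_mask` helper name and signature are assumptions for illustration, not part of this PR:

```python
from typing import Optional

import torch


def build_padding_mask(seqlens: list[int], device: torch.device) -> Optional[torch.Tensor]:
    """Hypothetical helper mirroring the pattern in this diff: return a
    (bsz, max_seqlen) boolean padding mask, or None when no row is padded
    so attention kernels can take their unmasked fast path."""
    max_seqlen = max(seqlens)
    if all(s == max_seqlen for s in seqlens):
        return None  # every row is full-length: nothing to mask out
    mask = torch.zeros((len(seqlens), max_seqlen), dtype=torch.bool, device=device)
    for i, s in enumerate(seqlens):
        mask[i, :s] = True  # valid positions; padding stays False
    return mask


# Usage sketch:
# build_padding_mask([3, 5, 5], torch.device("cpu")) -> (3, 5) bool tensor
# build_padding_mask([5, 5, 5], torch.device("cpu")) -> None
```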