@@ -405,8 +405,9 @@ def forward(
         shift: Optional[Tensor] = None,
         scale: Optional[Tensor] = None,
         gate: Optional[Tensor] = None,
-        query_positions: Optional[Tensor] = None,
         attention_logit_biases: Optional[Tensor] = None,
+        segment_ids: Optional[Tensor] = None,
+        query_positions: Optional[Tensor] = None,
     ) -> Tensor:
         """The forward function of DiTAttentionLayer.

@@ -418,7 +419,12 @@ def forward(
                 target_dim] and shift should be provided.
             gate: If provided, applying before the residual addition with shape
                 [batch_size, 1|num_length, target_dim].
-            attention_logit_biases: Optional Tensor representing the self attention biases.
+            attention_logit_biases: Optional Tensor representing the self attention biases with
+                shape [batch_size, num_length, num_length].
+            segment_ids: Optional int Tensor representing the segment each token belongs to with
+                shape [batch_size, num_length].
+            query_positions: Optional Tensor representing the query positions when computing the
+                attention with shape [batch_size, num_length].

         Returns:
             A tensor with shape [batch_size, num_length, target_dim].
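
The added `segment_ids` argument is the usual hook for packing several sequences into one row of length `num_length`: attention should only connect tokens that share a segment id. The diff does not show how the underlying attention layer consumes it, so the sketch below only illustrates that convention by converting segment ids into additive logit biases of the documented `[batch_size, num_length, num_length]` shape; the helper name and the `-1e9` masking constant are assumptions, not part of this commit.

```python
import jax.numpy as jnp

NEG_INF = -1e9  # assumed "large negative" value for masked attention logits


def segment_ids_to_logit_biases(segment_ids: jnp.ndarray) -> jnp.ndarray:
    """Builds additive attention biases [batch, num_length, num_length] from
    int segment ids [batch, num_length]: 0 within a segment, NEG_INF across."""
    # same_segment[b, i, j] is True iff tokens i and j belong to the same segment.
    same_segment = segment_ids[:, :, None] == segment_ids[:, None, :]
    return jnp.where(same_segment, 0.0, NEG_INF)


# Two packed sequences of length 2 in a single row of length 4.
biases = segment_ids_to_logit_biases(jnp.array([[1, 1, 2, 2]]))
# biases[0, 0, 1] == 0.0 (same segment); biases[0, 0, 2] == -1e9 (masked)
```
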
@@ -442,7 +448,10 @@ def forward(
             x = modulate(x=x, shift=shift, scale=scale)

         x = self.attention(
-            query=x, query_positions=query_positions, attention_logit_biases=attention_logit_biases
+            query=x,
+            attention_logit_biases=attention_logit_biases,
+            segment_ids=segment_ids,
+            query_positions=query_positions,
         ).data

         if cfg.structure == "postnorm":
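
For packed inputs, `query_positions` normally restarts at zero inside each segment so that positional encodings line up with each original sequence rather than with the packed row. The helper below is a hypothetical companion to the sketch above, not part of the commit or of AXLearn's API; it derives such per-segment positions from `segment_ids`.

```python
import jax
import jax.numpy as jnp


def positions_from_segment_ids(segment_ids: jnp.ndarray) -> jnp.ndarray:
    """Returns positions [batch, num_length] that reset to 0 at each segment start."""
    num_length = segment_ids.shape[1]
    # A token starts a new segment when its id differs from the previous token's id.
    starts = jnp.concatenate(
        [
            jnp.ones_like(segment_ids[:, :1], dtype=bool),
            segment_ids[:, 1:] != segment_ids[:, :-1],
        ],
        axis=1,
    )
    # Index of the most recent segment start at or before each position.
    start_index = jax.lax.cummax(jnp.where(starts, jnp.arange(num_length), 0), axis=1)
    return jnp.arange(num_length) - start_index


positions = positions_from_segment_ids(jnp.array([[1, 1, 1, 2, 2]]))
# positions == [[0, 1, 2, 0, 1]]
```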