Commit 8fd9137

Add segment_ids option in DiTAttentionLayer (apple#976)
1 parent: e55a404

2 files changed (+50, -3 lines)

axlearn/common/dit.py

Lines changed: 12 additions & 3 deletions
@@ -405,8 +405,9 @@ def forward(
         shift: Optional[Tensor] = None,
         scale: Optional[Tensor] = None,
         gate: Optional[Tensor] = None,
-        query_positions: Optional[Tensor] = None,
         attention_logit_biases: Optional[Tensor] = None,
+        segment_ids: Optional[Tensor] = None,
+        query_positions: Optional[Tensor] = None,
     ) -> Tensor:
         """The forward function of DiTAttentionLayer.

@@ -418,7 +419,12 @@ def forward(
                 target_dim] and shift should be provided.
             gate: If provided, applying before the residual addition with shape
                 [batch_size, 1|num_length, target_dim].
-            attention_logit_biases: Optional Tensor representing the self attention biases.
+            attention_logit_biases: Optional Tensor representing the self attention biases with
+                shape [batch_size, num_length, num_length].
+            segment_ids: Optional int Tensor representing the segment each token belongs to with
+                shape [batch_size, num_length].
+            query_positions: Optional Tensor representing the query positions when computing the
+                attention with shape [batch_size, num_length]

         Returns:
             A tensor with shape [batch_size, num_length, target_dim].
@@ -442,7 +448,10 @@ def forward(
         x = modulate(x=x, shift=shift, scale=scale)

         x = self.attention(
-            query=x, query_positions=query_positions, attention_logit_biases=attention_logit_biases
+            query=x,
+            attention_logit_biases=attention_logit_biases,
+            segment_ids=segment_ids,
+            query_positions=query_positions,
         ).data

         if cfg.structure == "postnorm":

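For context, a rough sketch of how segment_ids are conventionally interpreted when sequences are packed: tokens may only attend to other tokens carrying the same segment id, which corresponds to a block-diagonal boolean mask. The actual masking in this commit happens inside the configured attention layer, so the helper below is an illustration of the convention, not axlearn's internal code path; segment_mask is a hypothetical name introduced here.

import jax.numpy as jnp

def segment_mask(segment_ids: jnp.ndarray) -> jnp.ndarray:
    # Illustrative helper (not part of this commit): True where the query token
    # and the key token share a segment id, i.e. where attention is allowed.
    # segment_ids: [batch_size, num_length] -> mask: [batch_size, num_length, num_length]
    return segment_ids[:, :, None] == segment_ids[:, None, :]

# Two packed sequences in one row: tokens 0-1 form segment 1, token 2 forms segment 2.
ids = jnp.asarray([[1, 1, 2]])
print(segment_mask(ids).astype(jnp.int32))
# [[[1 1 0]
#   [1 1 0]
#   [0 0 1]]]
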
axlearn/common/dit_test.py

Lines changed: 38 additions & 0 deletions
@@ -528,6 +528,44 @@ def test_dit_attn_logit_biases(self):
         # Expect the output be the same for valid items because of logit_biases.
         assert_allclose(layer_output * valid_mask, layer_output2 * valid_mask)

+    def test_dit_attn_segment_ids(self):
+        batch_size = 2
+        seq_len = 3
+        dim = 32
+        num_heads = 2
+
+        prng_key = jax.random.PRNGKey(0)
+        inputs = jax.random.normal(prng_key, shape=(batch_size, seq_len, dim))
+        shift = jax.random.normal(prng_key, shape=(batch_size, 1, dim))
+        scale = jax.random.normal(prng_key, shape=(batch_size, 1, dim))
+        gate = jax.random.normal(prng_key, shape=(batch_size, 1, dim))
+        segment_ids = jnp.ones((batch_size, seq_len))
+
+        layer_cfg = DiTAttentionLayer.default_config().set(
+            name="test",
+            source_dim=dim,
+            target_dim=dim,
+        )
+        layer_cfg.attention.num_heads = num_heads
+        layer_cfg.norm.eps = 1e-6
+        layer = layer_cfg.instantiate(parent=None)
+        state = layer.initialize_parameters_recursively(prng_key=prng_key)
+
+        layer_output, _ = F(
+            layer,
+            inputs=dict(
+                input=inputs,
+                shift=shift,
+                scale=scale,
+                gate=gate,
+                segment_ids=segment_ids,
+            ),
+            state=state,
+            is_training=False,
+            prng_key=prng_key,
+        )
+        assert_allclose(layer_output.shape, inputs.shape)
+
     @parameterized.parameters([True, False])
     def test_dit_attn_optional_input(self, use_ssg):
         batch_size = 2

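Note that the new test passes segment_ids of all ones, so it checks shapes on the single-segment path only. A hypothetical follow-up (not part of this commit) could pack two segments to confirm that cross-segment attention is actually masked:

# Hypothetical variant, assuming the same test setup as above (seq_len == 3):
segment_ids = jnp.asarray([[1, 1, 2], [1, 2, 2]])
# Reusing the same F(...) call, outputs for tokens in one segment should not
# change when tokens in the other segment are perturbed.
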