
Commit eaa1bd5

Clarify that setting sliding_window_size = 8 results in a window size of 9, including the token itself.
1 parent d27c562 commit eaa1bd5

File tree

2 files changed: +27 -0 lines changed


axlearn/common/attention_bias.py

Lines changed: 10 additions & 0 deletions
@@ -701,6 +701,12 @@ def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     """Returns a causal MaskFn for sliding window attentions of a given window size.
 
     Implements the `MaskFn` protocol.
+
+    Note: Setting sliding_window_size = 8 results in attending to 9 tokens - it attends to itself
+    and sliding_window_size tokens to the left.
+
+    Args:
+        sliding_window_size: Left context of sliding window mask.
     """
 
     def mask(query_position: Tensor, key_position: Tensor):
@@ -730,8 +736,12 @@ def make_causal_biases(seq_len: int) -> Tensor:
 def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
     """Generates attention logit biases for sliding window attention.
 
+    Note: Setting sliding_window_size = 8 results in attending to 9 tokens - it attends to itself
+    and sliding_window_size tokens to the left.
+
     Args:
         seq_len: Sequence length.
+        sliding_window_size: Left context of sliding window mask.
 
     Returns:
         A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
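
For reference, a minimal sketch (not part of the committed diff) of the off-by-one behavior the new docstring note describes, using sliding_window_causal_mask from this file; the sequence length of 16 is arbitrary:

import jax.numpy as jnp

from axlearn.common.attention_bias import sliding_window_causal_mask

# sliding_window_size = 8 admits the query token itself plus 8 tokens to its left.
mask_fn = sliding_window_causal_mask(sliding_window_size=8)
seq_len = 16
query_positions = jnp.arange(seq_len)[:, None]
key_positions = jnp.arange(seq_len)[None, :]
bool_mask = mask_fn(query_positions, key_positions)  # [seq_len, seq_len] boolean mask.

# The last query position can attend to 9 keys: itself and the 8 preceding tokens.
print(int(bool_mask[-1].sum()))  # 9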

axlearn/common/attention_bias_test.py

Lines changed: 17 additions & 0 deletions
@@ -16,10 +16,27 @@
     MaskFnAttentionBias,
     SegmentIdAttentionBias,
     TensorAttentionBias,
+    sliding_window_causal_mask,
 )
 from axlearn.common.utils import Tensor
 
 
+class MaskTest(test_utils.TestCase):
+    @parameterized.parameters(
+        [0, [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]],
+        [2, [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [0, 1, 1, 1, 0], [0, 0, 1, 1, 1]]],
+        [4, [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]],
+    )
+    def test_sliding_window_mask(self, left_context, expected):
+        mask_fn = sliding_window_causal_mask(sliding_window_size=left_context)
+        step_len = 5
+        target_positions = jnp.arange(step_len)[:, None]
+        source_positions = jnp.arange(step_len)[None, :]
+        bool_mask = mask_fn(target_positions, source_positions)
+        out_mask = bool_mask.astype(jnp.int32)
+        self.assertEqual(out_mask.tolist(), expected)
+
+
 class AttentionBiasTest(test_utils.TestCase):
     @parameterized.parameters(
         [attention_bias.ZeroAttentionBias(), False],
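
Similarly, a hedged sketch of the bias variant touched by this commit, mirroring the left_context=2 case of the test above; the assumption that attendable positions carry a bias of exactly 0 (with masked positions pushed to a large negative value) follows the docstring but is not verified here:

import jax.numpy as jnp

from axlearn.common.attention_bias import make_sliding_window_causal_biases

# sliding_window_size = 2 means each token attends to itself plus 2 tokens to the left,
# i.e. a window of 3 tokens per query position.
biases = make_sliding_window_causal_biases(seq_len=5, sliding_window_size=2)
attended = (biases == 0).astype(jnp.int32)  # assumption: bias 0 marks attendable positions
print(attended.tolist())
# Expected to match the test's left_context=2 pattern:
# [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [0, 1, 1, 1, 0], [0, 0, 1, 1, 1]]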
