@@ -16,10 +16,27 @@
     MaskFnAttentionBias,
     SegmentIdAttentionBias,
     TensorAttentionBias,
+    sliding_window_causal_mask,
 )
 from axlearn.common.utils import Tensor
 
 
+class MaskTest(test_utils.TestCase):
+    @parameterized.parameters(
+        [0, [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 1]]],
+        [2, [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [0, 1, 1, 1, 0], [0, 0, 1, 1, 1]]],
+        [4, [[1, 0, 0, 0, 0], [1, 1, 0, 0, 0], [1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]],
+    )
+    def test_sliding_window_mask(self, left_context, expected):
+        mask_fn = sliding_window_causal_mask(sliding_window_size=left_context)
+        step_len = 5
+        target_positions = jnp.arange(step_len)[:, None]
+        source_positions = jnp.arange(step_len)[None, :]
+        bool_mask = mask_fn(target_positions, source_positions)
+        out_mask = bool_mask.astype(jnp.int32)
+        self.assertEqual(out_mask.tolist(), expected)
+
+
 class AttentionBiasTest(test_utils.TestCase):
     @parameterized.parameters(
         [attention_bias.ZeroAttentionBias(), False],
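
For context, the semantics the new test pins down can be reproduced standalone. The sketch below is illustrative only: sliding_window_mask_sketch is a hypothetical stand-in for axlearn's sliding_window_causal_mask, assuming (per the expected masks above) that a target position attends to itself and to at most sliding_window_size earlier positions.

import jax.numpy as jnp

def sliding_window_mask_sketch(sliding_window_size: int):
    # Hypothetical re-implementation for illustration: a (target, source) pair
    # is visible iff it is causal (source <= target) and the source lies within
    # `sliding_window_size` positions to the left of the target.
    def mask_fn(target_positions, source_positions):
        causal = source_positions <= target_positions
        in_window = target_positions - source_positions <= sliding_window_size
        return jnp.logical_and(causal, in_window)
    return mask_fn

# Reproduces the left_context=2 case from MaskTest above.
positions = jnp.arange(5)
print(sliding_window_mask_sketch(2)(positions[:, None], positions[None, :]).astype(jnp.int32))

With sliding_window_size=0 this degenerates to a diagonal mask (each position sees only itself), and with sliding_window_size >= sequence length - 1 it becomes a full causal mask, matching the first and third parameter sets.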