     final,
 )
 
-import einops
 import jax
 from jax import numpy as jnp
 from jax.sharding import PartitionSpec
 
 from axlearn.common import struct
-from axlearn.common.config import ClassConfigBase, ConfigOr, config_for_class, maybe_instantiate
+from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.utils import Tensor
 
 
 NEG_INF = -1e15
@@ -439,16 +438,10 @@ def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
         # Segment IDs: [batch_size, seq_len].
-        # We use the partition spec of KV (which are not sequence sharded) for segment ids. This is
-        # because Splash requires two seq ids, q_seg and kv_seg. Therefore, we pass a not seq
-        # sharded seg ids into the shard map, and manually shard it inside for q_seg and not
-        # shard it for kv_seg.
-        kv_spec = mha_dim_to_partition_spec["bsnh"]
-        if kv_spec == PartitionSpec(None):
+        q_spec = mha_dim_to_partition_spec["btnh"]
+        if q_spec == PartitionSpec(None):
             return PartitionSpec(None)
-        if kv_spec[1] is not None:
-            raise ValueError("The partition spec of `s` in `bsnh` should be None.")
-        return PartitionSpec(kv_spec[0], kv_spec[1])
+        return PartitionSpec(q_spec[0], q_spec[1])
 
 
 class MaskFn(Protocol):
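
A small sketch (not part of the diff) of what the revised `partition_spec` returns for segment IDs: it now reuses the batch and sequence axes of the query layout `btnh`. The mesh axis names below are illustrative assumptions, not taken from the PR.

```python
from jax.sharding import PartitionSpec

# Assumed example layout; axis names ("data", "seq", "model") are illustrative only.
mha_dim_to_partition_spec = {
    "btnh": PartitionSpec("data", "seq", "model", None),
    "bsnh": PartitionSpec("data", None, "model", None),
    "bnts": PartitionSpec("data", "model", "seq", None),
}

q_spec = mha_dim_to_partition_spec["btnh"]
# Segment IDs [batch_size, seq_len] inherit the query's batch and sequence sharding.
assert PartitionSpec(q_spec[0], q_spec[1]) == PartitionSpec("data", "seq")
```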
@@ -496,25 +489,20 @@ class MaskFnAttentionBias(BoolAttentionBias):
 
     # The function defining the contents of the mask.
     mask: MaskFn = struct.field(pytree_node=False)
-
+    # The shape [target_len, source_len] of the mask.
+    shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
     # The positions in the query sequence that the mask should be computed for.
     # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
     # `target_positions[batch, i]` may attend to.
+    # If None, set `target_positions[batch, i] = i`.
+    # Shape: [batch] or [batch, target_len].
     # This is typically used during decoding to specify the locations in the sequence being
-    # being decoded.
-    # E.g., if we are decoding position 5 and 7 of the first and second batch entry respectively,
-    # we would set `target_positions = jnp.arange(steps)[None] + jnp.asarray([5, 7])`.
+    # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
+    # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
     # The motivation for supporting such shapes is for use cases where time_step in transformers
     # is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
     # various papers that need it.
-    # The index in the sequence of query vectors, [1|batch, target_len].
-    target_positions: Tensor = struct.field(kw_only=True)
-    # The index in the sequence of key vectors, [1|batch, source_len].
-    source_positions: Tensor = struct.field(kw_only=True)
-
-    @classmethod
-    def default_config(cls, mask: MaskFn) -> ClassConfigBase["MaskFnAttentionBias"]:
-        return config_for_class(MaskFnAttentionBias).set(mask=mask)
+    target_positions: Optional[Tensor] = None
 
     def _bool_value(self) -> Optional[Tensor]:
         """Return a tensor with the boolean values from `self.mask` before they have been converted
@@ -523,15 +511,29 @@ def _bool_value(self) -> Optional[Tensor]:
         Shape: [batch, target_len, source_len].
 
         Raises:
-            ValueError. If `(target|source)_positions.ndim not == 2`.
+            NotImplementedError. If `target_positions.ndim not in [1,2]`.
         """
-        target_positions, source_positions = self.target_positions, self.source_positions
-        if target_positions.ndim != source_positions.ndim != 2:
-            raise ValueError(
-                f"{target_positions.shape=} or {source_positions.shape=} is not rank 2."
-            )
-        target_positions = einops.rearrange(target_positions, "b t -> b t 1")
-        source_positions = einops.rearrange(source_positions, "b s -> b 1 s")
+        target_positions, source_positions = jnp.indices(self.shape, sparse=True)
+        # Shape: [1, target_len, 1], [1, 1, source_len].
+        target_positions, source_positions = target_positions[None], source_positions[None]
+        if self.target_positions is not None:
+            target_positions = self.target_positions
+            if target_positions.ndim not in [1, 2]:
+                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
+            if target_positions.ndim == 1:
+                # Shape: [batch, 1] + [target_len] = [batch, target_len]
+                # pylint: disable-next=unsubscriptable-object
+                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
+            elif target_positions.ndim == 2:
+                shape_with_batch_dim = (1, *self.shape)
+                # Raise an exception if shapes aren't compatible. We don't use the output.
+                jnp.broadcast_shapes(
+                    (target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
+                )
+            else:
+                raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
+            target_positions = target_positions[..., None]  # Shape: [batch, target_len, 1].
+
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
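
A minimal, standalone sketch (not from the PR) of the index construction that the new `_bool_value` performs for the 1-D `target_positions` case, using a plain causal comparison as a stand-in for `self.mask` and assuming we decode positions 5 and 7 for two batch entries:

```python
import jax.numpy as jnp

shape = (1, 8)  # [target_len, source_len]: one decode step against 8 source positions.
target_positions = jnp.asarray([5, 7])  # Per-batch decode positions, shape [batch].

t, s = jnp.indices(shape, sparse=True)
t, s = t[None], s[None]  # [1, target_len, 1], [1, 1, source_len]
# ndim == 1 branch: offset each batch entry's positions by arange(target_len).
t = target_positions[:, None] + jnp.arange(shape[0])  # [batch, target_len]
t = t[..., None]  # [batch, target_len, 1]

causal = t >= s  # Stand-in for `self.mask`; broadcasts to [batch, target_len, source_len].
print(causal.shape)  # (2, 1, 8)
print(causal[0, 0])  # [ True  True  True  True  True  True False False]
```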
@@ -554,26 +556,20 @@ def from_sequence(
             return super().from_sequence(biases)
         except NotImplementedError:
             pass
+        for bias in biases:
+            if bias.target_positions is not None:
+                raise ValueError(f"target_positions was not None for {bias}.")
 
         # Combine masks.
         mask = lambda query_position, key_position: jnp.all(
             jnp.stack([b.mask(query_position, key_position) for b in biases]), axis=0
         )
-        return MaskFnAttentionBias(
-            mask=mask,
-            target_positions=biases[0].target_positions,
-            source_positions=biases[0].source_positions,
-        )
+        return MaskFnAttentionBias(mask=mask, shape=biases[0].shape)
 
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
-        batch = mha_dim_to_partition_spec["bnts"][0]
-        return dataclasses.replace(
-            self,
-            target_positions=PartitionSpec(None if self.target_positions.shape[0] == 1 else batch),
-            source_positions=PartitionSpec(None if self.source_positions.shape[0] == 1 else batch),
-        )
+        return PartitionSpec(*mha_dim_to_partition_spec["bnts"][0:1])
 
 
 @struct.dataclass
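
An illustrative sketch (not from the PR) of the mask combination used in `from_sequence`: the component `MaskFn`s are AND-ed by stacking their boolean outputs, mirroring the lambda kept in the diff. The toy `causal` and `window` functions below are assumptions for demonstration.

```python
import jax.numpy as jnp

# Two toy MaskFns; `combined` mirrors the lambda in `from_sequence`.
causal = lambda q, k: q >= k
window = lambda q, k: q - k <= 2
component_masks = [causal, window]
combined = lambda q, k: jnp.all(jnp.stack([m(q, k) for m in component_masks]), axis=0)

q = jnp.arange(5)[:, None]
k = jnp.arange(5)[None, :]
print(combined(q, k)[4])  # [False False  True  True  True]
```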
@@ -644,10 +640,6 @@ class CausalAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error
 
     mask: Optional[MaskFn] = struct.field(pytree_node=False, default=causal_mask)
 
-    @classmethod
-    def default_config(cls) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(CausalAttentionBias)
-
     @classmethod
     def from_sequence(
         cls, biases: Sequence["CausalAttentionBias"]
@@ -659,23 +651,6 @@ def from_sequence(
         return biases[0]
 
 
-@struct.dataclass
-@final
-class SlidingWindowAttentionBias(MaskFnAttentionBias):  # pylint: disable=final-error
-    """A sliding window attention mask."""
-
-    # A left context size for sliding window attention. sliding window size = left context + 1.
-    left_context: int = struct.field(kw_only=True, pytree_node=False)
-
-    @classmethod
-    # pylint: disable-next=arguments-renamed
-    def default_config(cls, left_context: int) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(SlidingWindowAttentionBias).set(
-            mask=sliding_window_causal_mask(left_context=left_context),
-            left_context=left_context,
-        )
-
-
 @struct.dataclass
 @final
 class ZeroAttentionBias(BoolAttentionBias):
@@ -722,19 +697,19 @@ def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
     return _composite_masks(jnp.logical_and, *mask_fns)
 
 
-def sliding_window_causal_mask(left_context: int) -> MaskFn:
+def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     """Returns a causal MaskFn for sliding window attentions of a given window size.
 
     Implements the `MaskFn` protocol.
     """
 
     def mask(query_position: Tensor, key_position: Tensor):
-        pos_mask = query_position - key_position <= left_context
-        # Negative positions indicate prefill padding.
-        key_valid = key_position >= 0
-        return pos_mask & key_valid
+        return query_position - key_position <= sliding_window_size
 
     fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
     return fun
 
 
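
A self-contained sketch (not from the PR) of the renamed semantics: with `sliding_window_size=2`, query position i may attend to key positions i-2 through i. The helper below re-implements the causal-and-window check inline rather than importing the module, so the name `toy_sliding_window_causal` is a hypothetical stand-in.

```python
import jax.numpy as jnp

def toy_sliding_window_causal(sliding_window_size: int):
    # Mirrors and_masks(causal_mask, mask): causal AND (query - key <= window size).
    def mask_fn(query_position, key_position):
        causal = query_position >= key_position
        return causal & (query_position - key_position <= sliding_window_size)
    return mask_fn

mask_fn = toy_sliding_window_causal(2)
q = jnp.arange(6)[:, None]
k = jnp.arange(6)[None, :]
print(mask_fn(q, k)[5])  # [False False False  True  True  True]
```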
@@ -752,17 +727,17 @@ def make_causal_biases(seq_len: int) -> Tensor:
     return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
 
 
-def make_sliding_window_causal_biases(seq_len: int, left_context: int) -> Tensor:
+def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
     """Generates attention logit biases for sliding window attention.
 
     Args:
         seq_len: Sequence length.
 
     Returns:
         A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
-        if i - j > left_context or i < j, 0 otherwise.
+        if i - j > sliding_window_size or i < j, 0 otherwise.
     """
-    mask_fn = sliding_window_causal_mask(left_context)
+    mask_fn = sliding_window_causal_mask(sliding_window_size)
     return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
 
 
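
A toy reconstruction (not from the PR) of the bias pattern described in the updated docstring, using the same -1e15 sentinel as the module's `NEG_INF`; `toy_biases` is a hypothetical stand-in for `make_sliding_window_causal_biases`.

```python
import jax.numpy as jnp

NEG_INF = -1e15  # Same sentinel as the module constant.

def toy_biases(seq_len: int, sliding_window_size: int):
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    allowed = (i >= j) & (i - j <= sliding_window_size)
    return jnp.where(allowed, 0.0, NEG_INF)

# For seq_len=4, sliding_window_size=1, row i has 0 at columns max(0, i-1)..i, -1e15 elsewhere.
print(toy_biases(4, 1))
```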