     final,
 )
 
-import einops
 import jax
 from jax import numpy as jnp
 from jax.sharding import PartitionSpec
 
 from axlearn.common import struct
-from axlearn.common.config import ClassConfigBase, ConfigOr, config_for_class, maybe_instantiate
+from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.utils import Tensor
 
 NEG_INF = -1e15
@@ -439,16 +438,10 @@ def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
         # Segment IDs: [batch_size, seq_len].
-        # We use the partition spec of KV (which are not sequence sharded) for segment ids. This is
-        # because Splash requires two seq ids, q_seg and kv_seg. Therefore, we pass a not seq
-        # sharded seg ids into the shard map, and manually shard it inside for q_seg and not
-        # shard it for kv_seg.
-        kv_spec = mha_dim_to_partition_spec["bsnh"]
-        if kv_spec == PartitionSpec(None):
+        q_spec = mha_dim_to_partition_spec["btnh"]
+        if q_spec == PartitionSpec(None):
             return PartitionSpec(None)
-        if kv_spec[1] is not None:
-            raise ValueError("The partition spec of `s` in `bsnh` should be None.")
-        return PartitionSpec(kv_spec[0], kv_spec[1])
+        return PartitionSpec(q_spec[0], q_spec[1])
 
 
 class MaskFn(Protocol):
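
# Editor's note: a minimal standalone sketch (not part of this commit) of the new
# segment-ID spec logic in the hunk above. `segment_ids_spec` is a hypothetical helper
# name, and the mesh axis names below are only illustrative.
from jax.sharding import PartitionSpec

def segment_ids_spec(mha_dim_to_partition_spec: dict[str, PartitionSpec]) -> PartitionSpec:
    # Segment IDs are [batch, seq], so they follow the batch and sequence axes of the
    # query layout "btnh" rather than the KV layout "bsnh".
    q_spec = mha_dim_to_partition_spec["btnh"]
    if q_spec == PartitionSpec(None):
        return PartitionSpec(None)
    return PartitionSpec(q_spec[0], q_spec[1])

print(segment_ids_spec({"btnh": PartitionSpec(("data", "fsdp"), "seq", None, None)}))
# -> PartitionSpec(('data', 'fsdp'), 'seq')
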
@@ -496,25 +489,20 @@ class MaskFnAttentionBias(BoolAttentionBias):
 
     # The function defining the contents of the mask.
     mask: MaskFn = struct.field(pytree_node=False)
-
+    # The shape [target_len, source_len] of the mask.
+    shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
     # The positions in the query sequence that the mask should be computed for.
     # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
     # `target_positions[batch, i]` may attend to.
+    # If None, set `target_positions[batch, i] = i`.
+    # Shape: [batch] or [batch, target_len].
     # This is typically used during decoding to specify the locations in the sequence being
-    # being decoded.
-    # E.g., if we are decoding position 5 and 7 of the first and second batch entry respectively,
-    # we would set `target_positions = jnp.arange(steps)[None] + jnp.asarray([5, 7])`.
+    # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
+    # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
     # The motivation for supporting such shapes is for use cases where time_step in transformers
     # is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
     # various papers that need it.
-    # The index in the sequence of query vectors, [1|batch, target_len].
-    target_positions: Tensor = struct.field(kw_only=True)
-    # The index in the sequence of key vectors, [1|batch, source_len].
-    source_positions: Tensor = struct.field(kw_only=True)
-
-    @classmethod
-    def default_config(cls, mask: MaskFn) -> ClassConfigBase["MaskFnAttentionBias"]:
-        return config_for_class(MaskFnAttentionBias).set(mask=mask)
+    target_positions: Optional[Tensor] = None
 
     def _bool_value(self) -> Optional[Tensor]:
         """Return a tensor with the boolean values from `self.mask` before they have been converted
@@ -523,15 +511,29 @@ def _bool_value(self) -> Optional[Tensor]:
         Shape: [batch, target_len, source_len].
 
         Raises:
-            ValueError. If `(target|source)_positions.ndim not == 2`.
+            NotImplementedError. If `target_positions.ndim not in [1, 2]`.
         """
-        target_positions, source_positions = self.target_positions, self.source_positions
-        if target_positions.ndim != source_positions.ndim != 2:
-            raise ValueError(
-                f"{target_positions.shape=} or {source_positions.shape=} is not rank 2."
-            )
-        target_positions = einops.rearrange(target_positions, "b t -> b t 1")
-        source_positions = einops.rearrange(source_positions, "b s -> b 1 s")
+        target_positions, source_positions = jnp.indices(self.shape, sparse=True)
+        # Shape: [1, target_len, 1], [1, 1, source_len].
+        target_positions, source_positions = target_positions[None], source_positions[None]
+        if self.target_positions is not None:
+            target_positions = self.target_positions
+            if target_positions.ndim not in [1, 2]:
+                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
+            if target_positions.ndim == 1:
+                # Shape: [batch, 1] + [target_len] = [batch, target_len]
+                # pylint: disable-next=unsubscriptable-object
+                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
+            elif target_positions.ndim == 2:
+                shape_with_batch_dim = (1, *self.shape)
+                # Raise an exception if shapes aren't compatible. We don't use the output.
+                jnp.broadcast_shapes(
+                    (target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
+                )
+            else:
+                raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
+            target_positions = target_positions[..., None]  # Shape: [batch, target_len, 1].
+
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
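
# Editor's note: a self-contained sketch (not from this commit) of the broadcasting that
# the new `_bool_value` above performs for a 1-D `target_positions` during decoding.
# `causal` stands in for an arbitrary MaskFn; shapes are annotated inline.
import jax.numpy as jnp

def causal(query_position, key_position):
    return query_position >= key_position

shape = (2, 6)  # [target_len, source_len]: a 2-step decode against a length-6 KV cache.
target_positions = jnp.asarray([4, 2])  # Current decode positions of two batch entries.

t, s = jnp.indices(shape, sparse=True)  # Shapes: [2, 1] and [1, 6].
t, s = t[None], s[None]  # Shapes: [1, 2, 1] and [1, 1, 6].
# 1-D target_positions: offset each batch entry's start position by the decode step.
t = target_positions[:, None] + jnp.arange(shape[0])  # [[4, 5], [2, 3]], shape [batch, target_len].
t = t[..., None]  # Shape: [batch, target_len, 1].
print(causal(t, s).shape)  # (2, 2, 6): [batch, target_len, source_len].
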
@@ -554,26 +556,20 @@ def from_sequence(
             return super().from_sequence(biases)
         except NotImplementedError:
             pass
+        for bias in biases:
+            if bias.target_positions is not None:
+                raise ValueError(f"target_positions was not None for {bias}.")
 
         # Combine masks.
         mask = lambda query_position, key_position: jnp.all(
             jnp.stack([b.mask(query_position, key_position) for b in biases]), axis=0
         )
-        return MaskFnAttentionBias(
-            mask=mask,
-            target_positions=biases[0].target_positions,
-            source_positions=biases[0].source_positions,
-        )
+        return MaskFnAttentionBias(mask=mask, shape=biases[0].shape)
 
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
-        batch = mha_dim_to_partition_spec["bnts"][0]
-        return dataclasses.replace(
-            self,
-            target_positions=PartitionSpec(None if self.target_positions.shape[0] == 1 else batch),
-            source_positions=PartitionSpec(None if self.source_positions.shape[0] == 1 else batch),
-        )
+        return PartitionSpec(*mha_dim_to_partition_spec["bnts"][0:1])
 
 
 @struct.dataclass
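
# Editor's note: an illustrative sketch (not from this commit) of the combination rule
# used in `from_sequence` above: each component MaskFn is evaluated and the results are
# AND-ed, so a position is visible only if every mask allows it.
import jax.numpy as jnp

def causal(q, k):
    return q >= k

def window(q, k):
    return q - k <= 2  # A hypothetical second mask.

component_masks = [causal, window]
combined = lambda q, k: jnp.all(jnp.stack([m(q, k) for m in component_masks]), axis=0)

q = jnp.arange(5)[:, None]
k = jnp.arange(5)[None, :]
print(combined(q, k).astype(jnp.int32))  # 1 only where both masks allow attention.
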
@@ -644,10 +640,6 @@ class CausalAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error
 
     mask: Optional[MaskFn] = struct.field(pytree_node=False, default=causal_mask)
 
-    @classmethod
-    def default_config(cls) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(CausalAttentionBias)
-
     @classmethod
     def from_sequence(
         cls, biases: Sequence["CausalAttentionBias"]
@@ -659,23 +651,6 @@ def from_sequence(
         return biases[0]
 
 
-@struct.dataclass
-@final
-class SlidingWindowAttentionBias(MaskFnAttentionBias):  # pylint: disable=final-error
-    """A sliding window attention mask."""
-
-    # A left context size for sliding window attention. sliding window size = left context + 1.
-    left_context: int = struct.field(kw_only=True, pytree_node=False)
-
-    @classmethod
-    # pylint: disable-next=arguments-renamed
-    def default_config(cls, left_context: int) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(SlidingWindowAttentionBias).set(
-            mask=sliding_window_causal_mask(left_context=left_context),
-            left_context=left_context,
-        )
-
-
 @struct.dataclass
 @final
 class ZeroAttentionBias(BoolAttentionBias):
@@ -722,19 +697,19 @@ def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
     return _composite_masks(jnp.logical_and, *mask_fns)
 
 
-def sliding_window_causal_mask(left_context: int) -> MaskFn:
+def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     """Returns a causal MaskFn for sliding window attentions of a given window size.
 
     Implements the `MaskFn` protocol.
     """
 
     def mask(query_position: Tensor, key_position: Tensor):
-        pos_mask = query_position - key_position <= left_context
-        # Negative positions indicate prefill padding.
-        key_valid = key_position >= 0
-        return pos_mask & key_valid
+        return query_position - key_position <= sliding_window_size
 
     fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
     return fun
 
 
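
# Editor's note: a usage sketch of the renamed parameter, assuming the module path
# `axlearn.common.attention_bias`. The `_sliding_window_size` attribute is what the
# Splash-attention path reads via `_to_splash_mask()`.
import jax.numpy as jnp
from axlearn.common.attention_bias import sliding_window_causal_mask  # assumed path

mask_fn = sliding_window_causal_mask(sliding_window_size=2)
print(mask_fn._sliding_window_size)  # -> 2
q = jnp.arange(5)[:, None]
k = jnp.arange(5)[None, :]
# Row i is True at columns max(0, i - 2) .. i: causal, with at most 2 earlier keys visible.
print(mask_fn(q, k).astype(jnp.int32))
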
@@ -752,17 +727,17 @@ def make_causal_biases(seq_len: int) -> Tensor:
     return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
 
 
-def make_sliding_window_causal_biases(seq_len: int, left_context: int) -> Tensor:
+def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
     """Generates attention logit biases for sliding window attention.
 
     Args:
         seq_len: Sequence length.
 
     Returns:
         A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
-        if i - j > left_context or i < j, 0 otherwise.
+        if i - j > sliding_window_size or i < j, 0 otherwise.
     """
-    mask_fn = sliding_window_causal_mask(left_context)
+    mask_fn = sliding_window_causal_mask(sliding_window_size)
     return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
 
 
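
# Editor's note: a short worked example (same assumed import path) of the bias matrix
# for seq_len=4, sliding_window_size=1. Allowed positions get 0, masked ones NEG_INF.
import jax.numpy as jnp
from axlearn.common.attention_bias import make_sliding_window_causal_biases  # assumed path

biases = make_sliding_window_causal_biases(seq_len=4, sliding_window_size=1)
print((biases == 0).astype(jnp.int32))
# [[1 0 0 0]
#  [1 1 0 0]
#  [0 1 1 0]
#  [0 0 1 1]]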