
Commit 334f421

Author: Mark Lee
Reverts sliding window attention changes. (#1004)
* Revert "Fix flash decoding in GPU. (#999)" This reverts commit fdadfd8.
* Revert "Supports TPU context parallel training (#981)" This reverts commit e151d69.
* Revert "Implemented sliding window attention to maintain KV cache only for the window size to enable infinite decoding. (#995)" This reverts commit 67645d0.
* Retain model/decoder ASR changes.
Parent: 3dacc6b. Commit: 334f421.

File tree: 108 files changed (+1649, -2238 lines).


axlearn/common/attention.py: 102 additions & 443 deletions (large diff not rendered by default).

axlearn/common/attention_bias.py: 46 additions & 71 deletions.
@@ -33,13 +33,12 @@
     final,
 )
 
-import einops
 import jax
 from jax import numpy as jnp
 from jax.sharding import PartitionSpec
 
 from axlearn.common import struct
-from axlearn.common.config import ClassConfigBase, ConfigOr, config_for_class, maybe_instantiate
+from axlearn.common.config import ConfigOr, maybe_instantiate
 from axlearn.common.utils import Tensor
 
 NEG_INF = -1e15
@@ -439,16 +438,10 @@ def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
         # Segment IDs: [batch_size, seq_len].
-        # We use the partition spec of KV (which are not sequence sharded) for segment ids. This is
-        # because Splash requires two seq ids, q_seg and kv_seg. Therefore, we pass a not seq
-        # sharded seg ids into the shard map, and manually shard it inside for q_seg and not
-        # shard it for kv_seg.
-        kv_spec = mha_dim_to_partition_spec["bsnh"]
-        if kv_spec == PartitionSpec(None):
+        q_spec = mha_dim_to_partition_spec["btnh"]
+        if q_spec == PartitionSpec(None):
             return PartitionSpec(None)
-        if kv_spec[1] is not None:
-            raise ValueError("The partition spec of `s` in `bsnh` should be None.")
-        return PartitionSpec(kv_spec[0], kv_spec[1])
+        return PartitionSpec(q_spec[0], q_spec[1])
 
 
 class MaskFn(Protocol):
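For illustration only (not part of the commit): a minimal sketch of how the restored logic maps a query ("btnh") partition spec to the [batch_size, seq_len] segment-ID spec. The mesh axis names "data" and "seq" are hypothetical.

from jax.sharding import PartitionSpec

def segment_id_spec(mha_dim_to_partition_spec: dict) -> PartitionSpec:
    # Segment IDs are [batch_size, seq_len], so reuse the batch and time axes
    # of the query ("btnh") spec, mirroring partition_spec() above.
    q_spec = mha_dim_to_partition_spec["btnh"]
    if q_spec == PartitionSpec(None):
        return PartitionSpec(None)
    return PartitionSpec(q_spec[0], q_spec[1])

print(segment_id_spec({"btnh": PartitionSpec("data", "seq", None, None)}))
# Expected: PartitionSpec('data', 'seq')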
@@ -496,25 +489,20 @@ class MaskFnAttentionBias(BoolAttentionBias):
 
     # The function defining the contents of the mask.
     mask: MaskFn = struct.field(pytree_node=False)
-
+    # The shape [target_len, source_len] of the mask.
+    shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
     # The positions in the query sequence that the mask should be computed for.
     # I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
     # `target_positions[batch, i]` may attend to.
+    # If None, set `target_positions[batch, i] = i`.
+    # Shape: [batch] or [batch, target_len].
     # This is typically used during decoding to specify the locations in the sequence being
-    # being decoded.
-    # E.g., if we are decoding position 5 and 7 of the first and second batch entry respectively,
-    # we would set `target_positions = jnp.arange(steps)[None] + jnp.asarray([5, 7])`.
+    # being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
+    # entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
     # The motivation for supporting such shapes is for use cases where time_step in transformers
     # is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
     # various papers that need it.
-    # The index in the sequence of query vectors, [1|batch, target_len].
-    target_positions: Tensor = struct.field(kw_only=True)
-    # The index in the sequence of key vectors, [1|batch, source_len].
-    source_positions: Tensor = struct.field(kw_only=True)
-
-    @classmethod
-    def default_config(cls, mask: MaskFn) -> ClassConfigBase["MaskFnAttentionBias"]:
-        return config_for_class(MaskFnAttentionBias).set(mask=mask)
+    target_positions: Optional[Tensor] = None
 
     def _bool_value(self) -> Optional[Tensor]:
         """Return a tensor with the boolean values from `self.mask` before they have been converted
@@ -523,15 +511,29 @@ def _bool_value(self) -> Optional[Tensor]:
         Shape: [batch, target_len, source_len].
 
         Raises:
-            ValueError. If `(target|source)_positions.ndim not == 2`.
+            NotImplementedError. If `target_positions.ndim not in [1,2]`.
         """
-        target_positions, source_positions = self.target_positions, self.source_positions
-        if target_positions.ndim != source_positions.ndim != 2:
-            raise ValueError(
-                f"{target_positions.shape=} or {source_positions.shape=} is not rank 2."
-            )
-        target_positions = einops.rearrange(target_positions, "b t -> b t 1")
-        source_positions = einops.rearrange(source_positions, "b s -> b 1 s")
+        target_positions, source_positions = jnp.indices(self.shape, sparse=True)
+        # Shape: [1, target_len, 1], [1, 1, source_len].
+        target_positions, source_positions = target_positions[None], source_positions[None]
+        if self.target_positions is not None:
+            target_positions = self.target_positions
+            if target_positions.ndim not in [1, 2]:
+                raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
+            if target_positions.ndim == 1:
+                # Shape: [batch, 1] + [target_len] = [batch, target_len]
+                # pylint: disable-next=unsubscriptable-object
+                target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
+            elif target_positions.ndim == 2:
+                shape_with_batch_dim = (1, *self.shape)
+                # Raise an exception if shapes aren't compatible. We don't use the output.
+                jnp.broadcast_shapes(
+                    (target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
+                )
+            else:
+                raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
+            target_positions = target_positions[..., None]  # Shape: [batch, target_len, 1].
+
         return self.mask(target_positions, source_positions)  # pylint: disable=not-callable
 
     @classmethod
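For illustration only (not part of this commit): a minimal usage sketch of the restored MaskFnAttentionBias fields and the _bool_value() path above. The decode position, batch size, and cache length are made up.

import jax.numpy as jnp
from axlearn.common.attention_bias import MaskFnAttentionBias, causal_mask

# Decode one step at absolute position 2, batch of 1, KV cache of length 8.
# `shape` is [target_len, source_len]; `target_positions` gives the position
# of the step being decoded (ndim 1 takes the first branch above).
bias = MaskFnAttentionBias(mask=causal_mask, shape=(1, 8), target_positions=jnp.asarray([2]))

# Internally, jnp.indices(self.shape, sparse=True) produces broadcastable
# [target_len, 1] and [1, source_len] grids, so the MaskFn is evaluated over
# the whole grid without materializing full index matrices; source positions
# greater than 2 end up with a large negative bias (NEG_INF).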
@@ -554,26 +556,20 @@ def from_sequence(
             return super().from_sequence(biases)
         except NotImplementedError:
             pass
+        for bias in biases:
+            if bias.target_positions is not None:
+                raise ValueError(f"target_positions was not None for {bias}.")
 
         # Combine masks.
         mask = lambda query_position, key_position: jnp.all(
             jnp.stack([b.mask(query_position, key_position) for b in biases]), axis=0
         )
-        return MaskFnAttentionBias(
-            mask=mask,
-            target_positions=biases[0].target_positions,
-            source_positions=biases[0].source_positions,
-        )
+        return MaskFnAttentionBias(mask=mask, shape=biases[0].shape)
 
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union[BaseAttentionBias, PartitionSpec]:
-        batch = mha_dim_to_partition_spec["bnts"][0]
-        return dataclasses.replace(
-            self,
-            target_positions=PartitionSpec(None if self.target_positions.shape[0] == 1 else batch),
-            source_positions=PartitionSpec(None if self.source_positions.shape[0] == 1 else batch),
-        )
+        return PartitionSpec(*mha_dim_to_partition_spec["bnts"][0:1])
 
 
 @struct.dataclass
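A standalone sketch (not from the commit) of the combination pattern in from_sequence above: member MaskFns are AND-ed by stacking their boolean outputs and reducing with jnp.all. The causal and window functions here are hypothetical stand-ins with an arbitrary window of 2.

import jax.numpy as jnp

def causal(query_position, key_position):
    return query_position >= key_position

def window(query_position, key_position):
    return query_position - key_position <= 2

masks = (causal, window)
combined = lambda q, k: jnp.all(jnp.stack([m(q, k) for m in masks]), axis=0)

q = jnp.arange(6)[:, None]  # [6, 1]
k = jnp.arange(6)[None, :]  # [1, 6]
print(combined(q, k).astype(jnp.int32))  # a banded lower-triangular 0/1 matrix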
@@ -644,10 +640,6 @@ class CausalAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error
 
     mask: Optional[MaskFn] = struct.field(pytree_node=False, default=causal_mask)
 
-    @classmethod
-    def default_config(cls) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(CausalAttentionBias)
-
     @classmethod
     def from_sequence(
         cls, biases: Sequence["CausalAttentionBias"]
@@ -659,23 +651,6 @@ def from_sequence(
         return biases[0]
 
 
-@struct.dataclass
-@final
-class SlidingWindowAttentionBias(MaskFnAttentionBias):  # pylint: disable=final-error
-    """A sliding window attention mask."""
-
-    # A left context size for sliding window attention. sliding window size = left context + 1.
-    left_context: int = struct.field(kw_only=True, pytree_node=False)
-
-    @classmethod
-    # pylint: disable-next=arguments-renamed
-    def default_config(cls, left_context: int) -> ClassConfigBase[MaskFnAttentionBias]:
-        return config_for_class(SlidingWindowAttentionBias).set(
-            mask=sliding_window_causal_mask(left_context=left_context),
-            left_context=left_context,
-        )
-
-
 @struct.dataclass
 @final
 class ZeroAttentionBias(BoolAttentionBias):
@@ -722,19 +697,19 @@ def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
     return _composite_masks(jnp.logical_and, *mask_fns)
 
 
-def sliding_window_causal_mask(left_context: int) -> MaskFn:
+def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     """Returns a causal MaskFn for sliding window attentions of a given window size.
 
     Implements the `MaskFn` protocol.
     """
 
     def mask(query_position: Tensor, key_position: Tensor):
-        pos_mask = query_position - key_position <= left_context
-        # Negative positions indicate prefill padding.
-        key_valid = key_position >= 0
-        return pos_mask & key_valid
+        return query_position - key_position <= sliding_window_size
 
     fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
     return fun
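A usage sketch for the restored sliding_window_causal_mask; the window size and sequence length below are illustrative.

import jax.numpy as jnp
from axlearn.common.attention_bias import sliding_window_causal_mask

mask_fn = sliding_window_causal_mask(sliding_window_size=2)
positions = jnp.arange(5)
allowed = mask_fn(positions[:, None], positions[None, :])  # [5, 5] boolean
# allowed[i, j] is True iff j <= i and i - j <= 2, e.g. row 4 attends to 2, 3, 4.
assert mask_fn._sliding_window_size == 2  # attribute read back by the Splash mask path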
@@ -752,17 +727,17 @@ def make_causal_biases(seq_len: int) -> Tensor:
     return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
 
 
-def make_sliding_window_causal_biases(seq_len: int, left_context: int) -> Tensor:
+def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
     """Generates attention logit biases for sliding window attention.
 
     Args:
         seq_len: Sequence length.
 
     Returns:
         A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
-        if i - j > left_context or i < j, 0 otherwise.
+        if i - j > sliding_window_size or i < j, 0 otherwise.
     """
-    mask_fn = sliding_window_causal_mask(left_context)
+    mask_fn = sliding_window_causal_mask(sliding_window_size)
     return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
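And a quick sanity check (illustrative sizes, assuming allowed entries are exactly 0 per the docstring) of the layout make_sliding_window_causal_biases produces.

import jax.numpy as jnp
from axlearn.common.attention_bias import make_sliding_window_causal_biases

biases = make_sliding_window_causal_biases(seq_len=4, sliding_window_size=1)
print((biases == 0).astype(jnp.int32))
# Expected allowed pattern (1 = may attend):
# [[1 0 0 0]
#  [1 1 0 0]
#  [0 1 1 0]
#  [0 0 1 1]]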