Reverts sliding window attention changes. #1004

Merged · 4 commits · Feb 19, 2025
545 changes: 102 additions & 443 deletions axlearn/common/attention.py

Large diffs are not rendered by default.

117 changes: 46 additions & 71 deletions axlearn/common/attention_bias.py
@@ -33,13 +33,12 @@
final,
)

import einops
import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec

from axlearn.common import struct
from axlearn.common.config import ClassConfigBase, ConfigOr, config_for_class, maybe_instantiate
from axlearn.common.config import ConfigOr, maybe_instantiate
from axlearn.common.utils import Tensor

NEG_INF = -1e15
@@ -439,16 +438,10 @@ def partition_spec(
self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
) -> Union[BaseAttentionBias, PartitionSpec]:
# Segment IDs: [batch_size, seq_len].
# We use the partition spec of KV (which are not sequence sharded) for segment ids. This is
# because Splash requires two seq ids, q_seg and kv_seg. Therefore, we pass a not seq
# sharded seg ids into the shard map, and manually shard it inside for q_seg and not
# shard it for kv_seg.
kv_spec = mha_dim_to_partition_spec["bsnh"]
if kv_spec == PartitionSpec(None):
q_spec = mha_dim_to_partition_spec["btnh"]
if q_spec == PartitionSpec(None):
return PartitionSpec(None)
if kv_spec[1] is not None:
raise ValueError("The partition spec of `s` in `bsnh` should be None.")
return PartitionSpec(kv_spec[0], kv_spec[1])
return PartitionSpec(q_spec[0], q_spec[1])


class MaskFn(Protocol):
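A minimal standalone sketch (not part of the diff) of what the restored `partition_spec` logic above does for segment IDs: the `[batch, seq_len]` segment-ID tensor reuses the batch and sequence axes of the query spec `"btnh"`. The mesh axis names below are hypothetical.

```python
from jax.sharding import PartitionSpec

def segment_ids_partition_spec(mha_dim_to_partition_spec: dict) -> PartitionSpec:
    # Segment IDs are [batch, seq_len]; take the "b" and "t" axes of the query spec.
    q_spec = mha_dim_to_partition_spec["btnh"]
    if q_spec == PartitionSpec(None):
        return PartitionSpec(None)
    return PartitionSpec(q_spec[0], q_spec[1])

# E.g., with hypothetical mesh axes ("data", "seq"):
assert segment_ids_partition_spec(
    {"btnh": PartitionSpec("data", "seq", "model", None)}
) == PartitionSpec("data", "seq")
```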
@@ -496,25 +489,20 @@ class MaskFnAttentionBias(BoolAttentionBias):

# The function defining the contents of the mask.
mask: MaskFn = struct.field(pytree_node=False)

# The shape [target_len, source_len] of the mask.
shape: tuple[int, ...] = struct.field(kw_only=True, pytree_node=False)
# The positions in the query sequence that the mask should be computed for.
# I.e., `self.value()[batch, num_heads, i]` is the mask specifying what the query token at
# `target_positions[batch, i]` may attend to.
# If None, set `target_positions[batch, i] = i`.
# Shape: [batch] or [batch, target_len]`.
# This is typically used during decoding to specify the locations in the sequence being
# being decoded.
# E.g., if we are decoding position 5 and 7 of the first and second batch entry respectively,
# we would set `target_positions = jnp.arange(steps)[None] + jnp.asarray([5, 7])`.
# being decoded. E.g., if we are decoding position 5 and 7 of the first and second batch
# entry respectively, we would set `target_positions = jnp.asarray([5, 7])`.
# The motivation for supporting such shapes is for use cases where time_step in transformers
# is not necessarily contiguous. E.g., speculative decoding, non-contiguous prompts,
# various papers that need it.
# The index in the sequence of query vectors, [1|batch, target_len].
target_positions: Tensor = struct.field(kw_only=True)
# The index in the sequence of key vectors, [1|batch, source_len].
source_positions: Tensor = struct.field(kw_only=True)

@classmethod
def default_config(cls, mask: MaskFn) -> ClassConfigBase["MaskFnAttentionBias"]:
return config_for_class(MaskFnAttentionBias).set(mask=mask)
target_positions: Optional[Tensor] = None

def _bool_value(self) -> Optional[Tensor]:
"""Return a tensor with the boolean values from `self.mask` before they have been converted
@@ -523,15 +511,29 @@ def _bool_value(self) -> Optional[Tensor]:
Shape: [batch, target_len, source_len].

Raises:
ValueError. If `(target|source)_positions.ndim not == 2`.
NotImplementedError. If `target_positions.ndim not in [1,2]`.
"""
target_positions, source_positions = self.target_positions, self.source_positions
if target_positions.ndim != source_positions.ndim != 2:
raise ValueError(
f"{target_positions.shape=} or {source_positions.shape=} is not rank 2."
)
target_positions = einops.rearrange(target_positions, "b t -> b t 1")
source_positions = einops.rearrange(source_positions, "b s -> b 1 s")
target_positions, source_positions = jnp.indices(self.shape, sparse=True)
# Shape: [1, target_len, 1], [1, 1, source_len].
target_positions, source_positions = target_positions[None], source_positions[None]
if self.target_positions is not None:
target_positions = self.target_positions
if target_positions.ndim not in [1, 2]:
raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
if target_positions.ndim == 1:
# Shape: [batch, 1] + [target_len] = [batch, target_len]
# pylint: disable-next=unsubscriptable-object
target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
elif target_positions.ndim == 2:
shape_with_batch_dim = (1, *self.shape)
# Raise an exception if shapes aren't compatible. We don't use the output.
jnp.broadcast_shapes(
(target_positions.shape[0], 1, target_positions.shape[1]), shape_with_batch_dim
)
else:
raise NotImplementedError(f"Invalid value {target_positions.ndim=}.")
target_positions = target_positions[..., None] # Shape: [batch, target_len, 1].

return self.mask(target_positions, source_positions) # pylint: disable=not-callable

@classmethod
@@ -554,26 +556,20 @@ def from_sequence(
return super().from_sequence(biases)
except NotImplementedError:
pass
for bias in biases:
if bias.target_positions is not None:
raise ValueError(f"target_positions was not None for {bias}.")

# Combine masks.
mask = lambda query_position, key_position: jnp.all(
jnp.stack([b.mask(query_position, key_position) for b in biases]), axis=0
)
return MaskFnAttentionBias(
mask=mask,
target_positions=biases[0].target_positions,
source_positions=biases[0].source_positions,
)
return MaskFnAttentionBias(mask=mask, shape=biases[0].shape)

def partition_spec(
self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
) -> Union[BaseAttentionBias, PartitionSpec]:
batch = mha_dim_to_partition_spec["bnts"][0]
return dataclasses.replace(
self,
target_positions=PartitionSpec(None if self.target_positions.shape[0] == 1 else batch),
source_positions=PartitionSpec(None if self.source_positions.shape[0] == 1 else batch),
)
return PartitionSpec(*mha_dim_to_partition_spec["bnts"][0:1])


@struct.dataclass
@@ -644,10 +640,6 @@ class CausalAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error

mask: Optional[MaskFn] = struct.field(pytree_node=False, default=causal_mask)

@classmethod
def default_config(cls) -> ClassConfigBase[MaskFnAttentionBias]:
return config_for_class(CausalAttentionBias)

@classmethod
def from_sequence(
cls, biases: Sequence["CausalAttentionBias"]
@@ -659,23 +651,6 @@ def from_sequence(
return biases[0]


@struct.dataclass
@final
class SlidingWindowAttentionBias(MaskFnAttentionBias): # pylint: disable=final-error
"""A sliding window attention mask."""

# A left context size for sliding window attention. sliding window size = left context + 1.
left_context: int = struct.field(kw_only=True, pytree_node=False)

@classmethod
# pylint: disable-next=arguments-renamed
def default_config(cls, left_context: int) -> ClassConfigBase[MaskFnAttentionBias]:
return config_for_class(SlidingWindowAttentionBias).set(
mask=sliding_window_causal_mask(left_context=left_context),
left_context=left_context,
)


@struct.dataclass
@final
class ZeroAttentionBias(BoolAttentionBias):
@@ -722,19 +697,19 @@ def and_masks(*mask_fns: ConfigOr[MaskFn]) -> MaskFn:
return _composite_masks(jnp.logical_and, *mask_fns)


def sliding_window_causal_mask(left_context: int) -> MaskFn:
def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
"""Returns a causal MaskFn for sliding window attentions of a given window size.

Implements the `MaskFn` protocol.
"""

def mask(query_position: Tensor, key_position: Tensor):
pos_mask = query_position - key_position <= left_context
# Negative positions indicate prefill padding.
key_valid = key_position >= 0
return pos_mask & key_valid
return query_position - key_position <= sliding_window_size

fun = and_masks(causal_mask, mask)
# Flash attention needs to recognize sliding window size in _to_splash_mask().
# pylint: disable-next=protected-access
fun._sliding_window_size = sliding_window_size
return fun
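A brief usage sketch (not part of the diff): with the restored signature, each query position may attend to itself and to the previous `sliding_window_size` key positions, and the window size is stashed on the returned callable so the Splash-attention lowering can recover it.

```python
import jax.numpy as jnp
from axlearn.common.attention_bias import sliding_window_causal_mask

mask_fn = sliding_window_causal_mask(sliding_window_size=2)

positions = jnp.arange(5)
allowed = mask_fn(positions[:, None], positions[None, :])  # [5, 5] bool
# Row i is True exactly for keys in [max(0, i - 2), i]; e.g. row 4 -> keys 2, 3, 4.

assert mask_fn._sliding_window_size == 2  # read by _to_splash_mask()
```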


@@ -752,17 +727,17 @@ def make_causal_biases(seq_len: int) -> Tensor:
return bool_to_bias(causal_mask(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))


def make_sliding_window_causal_biases(seq_len: int, left_context: int) -> Tensor:
def make_sliding_window_causal_biases(seq_len: int, sliding_window_size: int) -> Tensor:
"""Generates attention logit biases for sliding window attention.

Args:
seq_len: Sequence length.

Returns:
A float tensor of shape [seq_len, seq_len] where the value at [i, j] = -inf
if i - j > left_context or i < j, 0 otherwise.
if i - j > sliding_window_size or i < j, 0 otherwise.
"""
mask_fn = sliding_window_causal_mask(left_context)
mask_fn = sliding_window_causal_mask(sliding_window_size)
return bool_to_bias(mask_fn(jnp.arange(seq_len)[:, None], jnp.arange(seq_len)[None, :]))
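Finally, a tiny worked example (not part of the diff): for `seq_len=4` and `sliding_window_size=1`, query `i` gets a zero bias toward keys `i-1` and `i`, and a large negative bias elsewhere.

```python
from axlearn.common.attention_bias import make_sliding_window_causal_biases

biases = make_sliding_window_causal_biases(seq_len=4, sliding_window_size=1)
# Allowed (0) vs masked (large negative) pattern, row = query i, column = key j:
#   [0, -, -, -]
#   [0, 0, -, -]
#   [-, 0, 0, -]
#   [-, -, 0, 0]
```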

