diff --git a/CHANGELOG.md b/CHANGELOG.md
index d60823b9dd..324951da8b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - fMHA: Fixed a bug in cutlass backend forward pass where the logsumexp was not correctly calculated, resulting in wrong results in the BW pass. This would happen with MQA when one sequence has a query with `length%64 == 1`
 ### Added
+- fMHA: Added `LocalAttentionFromBottomRightMask` (local)
+- fMHA: Added `LowerTriangularFromBottomRightMask` (causal)
+- fMHA: Added `LowerTriangularFromBottomRightLocalAttentionMask` (local + causal)
 ### Removed
 - Removed `xformers.triton.sum_strided`
 
diff --git a/docs/source/_static/local_attn.png b/docs/source/_static/local_attn.png
new file mode 100644
index 0000000000..b6e58c01ce
Binary files /dev/null and b/docs/source/_static/local_attn.png differ
diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 2cea3395cc..6fc5afc34a 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -352,6 +352,7 @@ def create_tensors(
             fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
             fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
             fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
+            fmha.attn_bias.LocalAttentionFromBottomRightMask,
         ),
     )
     if mask_is_bottom_right and q_len > kv_len:
@@ -2163,4 +2164,17 @@ def test_empty_tensors_empty_b(
     out.backward(out)
 
 
+def test_local_attn_bias() -> None:
+    mask = (
+        fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
+        .materialize(shape=(4, 4))
+        .exp()
+    )
+
+    expected = torch.tensor(
+        [[1, 1, 1, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32
+    )
+    assert (mask == expected).all().item()
+
+
 # end of file
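The expected tensor hard-coded in `test_local_attn_bias` above follows directly from the documented predicate `q - window_left <= k + s <= q + window_right`, with `s = num_queries - num_keys`. As a cross-check, here is a small illustrative sketch; it is not part of the patch, the helper name `reference_local_mask` is invented for this example, and it assumes an xformers build that already includes `LocalAttentionFromBottomRightMask`:

    import torch

    from xformers.ops import fmha


    def reference_local_mask(
        num_queries: int, num_keys: int, window_left: int, window_right: int
    ) -> torch.Tensor:
        # Brute-force evaluation of the predicate from the docstring:
        #   q - window_left <= k + s <= q + window_right, with s = num_queries - num_keys
        s = num_queries - num_keys
        q = torch.arange(num_queries)[:, None]
        k = torch.arange(num_keys)[None, :]
        return ((q - window_left <= k + s) & (k + s <= q + window_right)).float()


    bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
    keep = bias.materialize(shape=(4, 4)).exp()  # 0/1 keep-mask, exactly as in the test
    assert torch.equal(keep, reference_local_mask(4, 4, window_left=1, window_right=2))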
diff --git a/xformers/attn_bias_utils.py b/xformers/attn_bias_utils.py
index 2e4ee8435d..7bdbd5706c 100644
--- a/xformers/attn_bias_utils.py
+++ b/xformers/attn_bias_utils.py
@@ -161,6 +161,11 @@ def create_attn_bias(
             )
         )
         return g_block_diag
+    if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
+        return bias_type(
+            window_left=r.randint(0, 5),
+            window_right=r.randint(0, 5),
+        )
     assert False, f"Unsupported bias type: {bias_type}"
 
 
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index 7db7bd1c39..78044f7db5 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -88,6 +88,89 @@ def _materialize_causal_mask(
     return mask.to(dtype)
 
 
+@dataclass
+class LocalAttentionFromBottomRightMask(AttentionBias):
+    """
+    A local attention mask
+
+    The query at position :math:`q` can attend to the key at position :math:`k` if
+    :math:`q - window\\_left <= k + s <= q + window\\_right`
+
+    With :math:`s = num\\_queries - num\\_keys`
+
+    :Example:
+
+    .. code-block:: python
+
+        import torch
+        from xformers.ops import fmha
+
+        bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
+        print(bias.materialize(shape=(4, 4)).exp())
+        print(bias.materialize(shape=(4, 5)).exp())
+
+    .. code-block:: text
+
+        # 4x4
+        tensor([[1., 1., 1., 0.],
+                [1., 1., 1., 1.],
+                [0., 1., 1., 1.],
+                [0., 0., 1., 1.]])
+
+        # 4x5
+        tensor([[1., 1., 1., 1., 0.],
+                [0., 1., 1., 1., 1.],
+                [0., 0., 1., 1., 1.],
+                [0., 0., 0., 1., 1.]])
+
+    :Illustration:
+
+    .. figure:: /_static/local_attn.png
+        :width: 240px
+
+        The total window size is :math:`window\\_left + 1 + window\\_right`
+    """
+
+    window_left: int
+    window_right: int
+
+    def __post_init__(self) -> None:
+        if self.window_left < 0:
+            raise ValueError(
+                "Invalid window value passed to "
+                "`LocalAttentionFromBottomRightMask`: expected "
+                f"`window_left >= 0` but got window_left={self.window_left}"
+            )
+        if self.window_right < 0:
+            raise ValueError(
+                "Invalid window value passed to "
+                "`LocalAttentionFromBottomRightMask`: expected "
+                f"`window_right >= 0` but got window_right={self.window_right}"
+            )
+
+    def materialize(
+        self,
+        shape: Tuple[int, ...],
+        dtype: torch.dtype = torch.float32,
+        device: Union[str, torch.device] = "cpu",
+    ) -> torch.Tensor:
+        create_as = dtype if dtype is not torch.bfloat16 else torch.float32
+        mask = torch.full(  # type: ignore
+            shape,
+            dtype=create_as,
+            fill_value=1,
+            device=device,
+        )
+
+        num_queries, num_keys = shape[-2:]
+        shift = num_keys - num_queries
+
+        mask = torch.triu(mask, diagonal=shift - self.window_left)
+        mask = torch.tril(mask, diagonal=shift + self.window_right)
+        mask = torch.log(mask)
+        return mask.to(dtype)
+
+
 class LowerTriangularMask(AttentionBias):
     """
     A lower-triangular (aka causal) mask
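End-to-end, the new bias is passed to `memory_efficient_attention` like any of the existing attention bias classes. A minimal usage sketch (not taken from the patch; the shapes, dtype and CUDA device are illustrative assumptions, and a backend supporting this bias must be available):

    import torch

    import xformers.ops as xops
    from xformers.ops import fmha

    # [batch, seq_len, num_heads, head_dim], the layout expected by memory_efficient_attention
    q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)

    # Each query attends to 128 keys on its left and 128 keys on its right,
    # aligned to the bottom-right corner of the attention matrix.
    bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=128, window_right=128)
    out = xops.memory_efficient_attention(q, k, v, attn_bias=bias)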
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 668c53715a..f2806f8a35 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -20,6 +20,7 @@
     BlockDiagonalCausalMask,
     BlockDiagonalCausalWithOffsetPaddedKeysMask,
     BlockDiagonalMask,
+    LocalAttentionFromBottomRightMask,
     LowerTriangularFromBottomRightLocalAttentionMask,
     LowerTriangularFromBottomRightMask,
     LowerTriangularMask,
@@ -61,7 +62,8 @@
         "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
         "int max_seqlen_q, int max_seqlen_k, "
         "float p, float softmax_scale, "
-        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+        "bool is_causal, int window_left, "
+        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
     )
 
     _flash_lib.define(
@@ -69,7 +71,8 @@
         "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
         "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
         "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+        "float p, float softmax_scale, bool is_causal, "
+        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
     )
 
     def _flash_fwd(
@@ -84,7 +87,8 @@ def _flash_fwd(
         p,
         softmax_scale,
         is_causal,
-        window_size,
+        window_left,
+        window_right,
         return_softmax,
     ):
         if cu_seq_lens_q is None:
@@ -107,8 +111,8 @@ def _flash_fwd(
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                window_left,  # window_size_left
+                window_right,  # window_size_right
                 return_softmax,
                 None,  # rng
             )
@@ -137,8 +141,8 @@ def _flash_fwd(
                 softmax_scale,
                 False,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                window_left,
+                window_right,
                 return_softmax,
                 None,
             )
@@ -161,7 +165,8 @@ def _flash_bwd(
         p,
         softmax_scale,
        is_causal,
-        window_size,
+        window_left,
+        window_right,
         rng_state,
     ):
         if cu_seq_lens_k is None:
@@ -179,8 +184,8 @@ def _flash_bwd(
                 p,
                 softmax_scale,
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                window_left,
+                window_right,
                 None,
                 rng_state,
             )
@@ -203,8 +208,8 @@ def _flash_bwd(
                 softmax_scale,
                 False,  # zero_tensors
                 is_causal,
-                window_size - 1,  # window_size_left
-                -1,  # window_size_right
+                window_left,
+                window_right,
                 None,
                 rng_state,
             )
@@ -328,17 +333,24 @@ def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
     )
 
 
-def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
+def _window_size(
+    attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
+) -> Tuple[int, int]:
+    win_left = -1
+    win_right = -1
     if isinstance(
         attn_bias,
-        (BlockDiagonalCausalLocalAttentionMask,),
+        (
+            BlockDiagonalCausalLocalAttentionMask,
+            BlockDiagonalCausalLocalAttentionFromBottomRightMask,
+            LowerTriangularFromBottomRightLocalAttentionMask,
+        ),
     ):
-        return attn_bias._window_size or 0
-    if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
-        return attn_bias._window_size
-    if isinstance(attn_bias, LowerTriangularFromBottomRightLocalAttentionMask):
-        return attn_bias._window_size
-    return 0
+        win_left = attn_bias._window_size - 1
+    if isinstance(attn_bias, LocalAttentionFromBottomRightMask):
+        win_left = attn_bias.window_left
+        win_right = attn_bias.window_right
+    return (win_left, win_right)
 
 
 def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
@@ -404,6 +416,7 @@ class FwOp(AttentionFwOpBase):
         BlockDiagonalCausalLocalAttentionFromBottomRightMask,
         BlockDiagonalCausalFromBottomRightMask,
         BlockDiagonalCausalWithOffsetPaddedKeysMask,
+        LocalAttentionFromBottomRightMask,
     }
     SUPPORTS_DROPOUT = True
     SUPPORTS_CUSTOM_SCALE = True
@@ -441,6 +454,7 @@ def apply(
             seqused_k,
         ) = _convert_input_format(inp, supports_mqa=True)
         if inp.query.numel() > 0 and inp.key.numel() > 0:
+            win_left, win_right = _window_size(inp.attn_bias)
             out, softmax_lse, rng_state = cls.OPERATOR(
                 inp.query,
                 inp.key,
@@ -453,8 +467,9 @@
                 inp.p,
                 inp.scale_float,
                 _is_causal(inp.attn_bias),
-                _window_size(inp.attn_bias),
-                return_softmax,
+                window_left=win_left,
+                window_right=win_right,
+                return_softmax=return_softmax,
             )
             out = out.reshape(out_shape)
         else:
@@ -612,6 +627,7 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
         if grads.dv.numel() == 0:
             grads.dq.zero_()
         if grads.dq.numel() and grads.dk.numel():
+            win_left, win_right = _window_size(inp.attn_bias)
             cls.OPERATOR(
                 grad.reshape(kernel_out_shape).contiguous(),
                 inp.query,
@@ -629,8 +645,9 @@
                 inp.p,
                 inp.scale_float,
                 _is_causal(inp.attn_bias),
-                _window_size(inp.attn_bias),
-                ctx.rng_state,
+                window_left=win_left,
+                window_right=win_right,
+                rng_state=ctx.rng_state,
             )
             grads.dq = grads.dq.reshape(dq_shape)
             grads.dk = grads.dk.reshape(dk_shape)
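A note on the `_window_size` change above: the flash kernels take a `(window_size_left, window_size_right)` pair in which `-1` disables the limit on that side. The pre-existing causal local masks only need a left bound, since `is_causal` already cuts off the right side, and a window of `_window_size` keys that includes the diagonal corresponds to `window_size_left = _window_size - 1`; `LocalAttentionFromBottomRightMask` simply forwards both of its windows with `is_causal=False`. The sketch below illustrates that `- 1` relationship on materialized masks; it is not part of the patch and assumes `LowerTriangularFromBottomRightLocalAttentionMask` can be constructed directly from its window size:

    import torch

    from xformers.ops import fmha

    # A bottom-right causal local mask keeping `w` keys per query (the diagonal plus
    # w - 1 keys to its left) selects the same positions as a local mask with
    # window_left = w - 1 and window_right = 0, where window_right = 0 plays the role
    # of the causal cut-off. This is the reason for `win_left = attn_bias._window_size - 1`.
    w = 3
    causal_local = fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask(w)
    local = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=w - 1, window_right=0)

    shape = (5, 7)
    assert torch.equal(causal_local.materialize(shape), local.materialize(shape))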