[fmha] Added LocalAttentionFromBottomRightMask
ghstack-source-id: fad4f87f959883a7ffe04cb26b6b3de79d59f90e
Pull Request resolved: https://github.com/fairinternal/xformers/pull/952

__original_commit__ = fairinternal/xformers@e53571969e12212b5af5a18a71fb154673a06399
xFormers Bot committed Dec 6, 2023
1 parent 238341a commit b42af03
Showing 6 changed files with 146 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- fMHA: Fixed a bug in cutlass backend forward pass where the logsumexp was not correctly calculated, resulting in wrong results in the BW pass. This would happen with MQA when one sequence has a query with `length%64 == 1`
### Added
- fMHA: Added `LocalAttentionFromBottomRightMask` (local)
- fMHA: Added `LowerTriangularFromBottomRightMask` (causal)
- fMHA: Added `LowerTriangularFromBottomRightLocalAttentionMask` (local + causal)
### Removed
- Removed `xformers.triton.sum_strided`

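The three masks above are regular `AttentionBias` objects, so they slot into `memory_efficient_attention` like the existing biases. A minimal usage sketch for the new local mask (the shapes, dtype, device, and window sizes below are illustrative assumptions, and a GPU build with a flash backend that supports sliding-window attention is assumed):

```python
import torch

from xformers.ops import fmha, memory_efficient_attention

# Illustrative shapes: batch=2, seq_len=1024, heads=8, head_dim=64.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)

# Each query attends only to keys within 128 positions to its left and
# 128 to its right, with the window aligned to the bottom-right corner.
bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=128, window_right=128)
out = memory_efficient_attention(q, k, v, attn_bias=bias)
```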
Binary file added docs/source/_static/local_attn.png
14 changes: 14 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -352,6 +352,7 @@ def create_tensors(
fmha.attn_bias.BlockDiagonalCausalFromBottomRightMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionFromBottomRightMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionMask,
fmha.attn_bias.LocalAttentionFromBottomRightMask,
),
)
if mask_is_bottom_right and q_len > kv_len:
@@ -2163,4 +2164,17 @@ def test_empty_tensors_empty_b(
out.backward(out)


def test_local_attn_bias() -> None:
    mask = (
        fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
        .materialize(shape=(4, 4))
        .exp()
    )

    expected = torch.tensor(
        [[1, 1, 1, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32
    )
    assert (mask == expected).all().item()


# end of file
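The `expected` tensor in `test_local_attn_bias` follows directly from the window rule documented for the new mask: query `q` may attend key `k` when `q - window_left <= k + s <= q + window_right`, with `s = num_queries - num_keys`. A brute-force cross-check sketch (the helper below is hypothetical and not part of this commit):

```python
import torch


def brute_force_local_mask(
    num_queries: int, num_keys: int, window_left: int, window_right: int
) -> torch.Tensor:
    # Build the 0/1 mask directly from the inequality, one element at a time.
    s = num_queries - num_keys
    mask = torch.zeros(num_queries, num_keys)
    for q in range(num_queries):
        for k in range(num_keys):
            if q - window_left <= k + s <= q + window_right:
                mask[q, k] = 1.0
    return mask


# Reproduces the `expected` tensor above: with s=0, row q keeps keys q-1..q+2.
assert torch.equal(
    brute_force_local_mask(4, 4, window_left=1, window_right=2),
    torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=torch.float32),
)
```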
5 changes: 5 additions & 0 deletions xformers/attn_bias_utils.py
@@ -161,6 +161,11 @@ def create_attn_bias(
)
)
return g_block_diag
    if bias_type == fmha.attn_bias.LocalAttentionFromBottomRightMask:
        return bias_type(
            window_left=r.randint(0, 5),
            window_right=r.randint(0, 5),
        )

assert False, f"Unsupported bias type: {bias_type}"

83 changes: 83 additions & 0 deletions xformers/ops/fmha/attn_bias.py
@@ -88,6 +88,89 @@ def _materialize_causal_mask(
return mask.to(dtype)


@dataclass
class LocalAttentionFromBottomRightMask(AttentionBias):
    """
    A local attention mask

    The query at position :math:`q` can attend the key at position :math:`k` if
    :math:`q - window\\_left <= k + s <= q + window\\_right`

    With :math:`s = num\\_queries - num\\_keys`

    :Example:

    .. code-block:: python

        import torch
        from xformers.ops import fmha

        bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
        print(bias.materialize(shape=(4, 4)).exp())
        print(bias.materialize(shape=(4, 5)).exp())

    .. code-block:: text

        # 4x4
        tensor([[1., 1., 1., 0.],
                [1., 1., 1., 1.],
                [0., 1., 1., 1.],
                [0., 0., 1., 1.]])

        # 4x5
        tensor([[1., 1., 1., 1., 0.],
                [0., 1., 1., 1., 1.],
                [0., 0., 1., 1., 1.],
                [0., 0., 0., 1., 1.]])

    :Illustration:

    .. figure:: /_static/local_attn.png
        :width: 240px

        The total window size is :math:`window\\_left + 1 + window\\_right`
    """

    window_left: int
    window_right: int

    def __post_init__(self) -> None:
        if self.window_left < 0:
            raise ValueError(
                "Invalid window value passed to "
                "`LocalAttentionFromBottomRightMask`: expected "
                f"`window_left >= 0` but got window_left={self.window_left}"
            )
        if self.window_right < 0:
            raise ValueError(
                "Invalid window value passed to "
                "`LocalAttentionFromBottomRightMask`: expected "
                f"`window_right >= 0` but got window_right={self.window_right}"
            )

    def materialize(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        create_as = dtype if dtype is not torch.bfloat16 else torch.float32
        mask = torch.full(  # type: ignore
            shape,
            dtype=create_as,
            fill_value=1,
            device=device,
        )

        num_queries, num_keys = shape[-2:]
        shift = num_keys - num_queries

        mask = torch.triu(mask, diagonal=shift - self.window_left)
        mask = torch.tril(mask, diagonal=shift + self.window_right)
        mask = torch.log(mask)
        return mask.to(dtype)


class LowerTriangularMask(AttentionBias):
"""
A lower-triangular (aka causal) mask
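`LocalAttentionFromBottomRightMask.materialize` returns an *additive* bias: allowed positions get `log(1) = 0` and masked positions get `log(0) = -inf`, which is why the docstring example and the test call `.exp()` to recover a 0/1 pattern. A small sketch of how such a bias combines with attention scores (toy CPU shapes are an assumption; in practice the bias object is passed to `memory_efficient_attention` directly and never materialized):

```python
import torch

from xformers.ops import fmha

q = torch.randn(4, 8)  # 4 queries, head_dim=8
k = torch.randn(5, 8)  # 5 keys: the window is aligned to the bottom-right corner

bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2)
scores = q @ k.T / 8**0.5 + bias.materialize(shape=(4, 5))  # -inf outside the window
attn = scores.softmax(dim=-1)  # masked positions get exactly zero weight
```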
65 changes: 41 additions & 24 deletions xformers/ops/fmha/flash.py
@@ -20,6 +20,7 @@
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalMask,
LocalAttentionFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
LowerTriangularFromBottomRightMask,
LowerTriangularMask,
@@ -61,15 +62,17 @@
"Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, "
"bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
"bool is_causal, int window_left, "
"int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
)

_flash_lib.define(
"flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
"Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
"Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
"int max_seqlen_q, int max_seqlen_k, "
"float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
"float p, float softmax_scale, bool is_causal, "
"int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
)

def _flash_fwd(
@@ -84,7 +87,8 @@ def _flash_fwd(
p,
softmax_scale,
is_causal,
window_size,
window_left,
window_right,
return_softmax,
):
if cu_seq_lens_q is None:
@@ -107,8 +111,8 @@
p,
softmax_scale,
is_causal,
window_size - 1, # window_size_left
-1, # window_size_right
window_left, # window_size_left
window_right, # window_size_right
return_softmax,
None, # rng
)
@@ -137,8 +141,8 @@
softmax_scale,
False,
is_causal,
window_size - 1, # window_size_left
-1, # window_size_right
window_left,
window_right,
return_softmax,
None,
)
@@ -161,7 +165,8 @@ def _flash_bwd(
p,
softmax_scale,
is_causal,
window_size,
window_left,
window_right,
rng_state,
):
if cu_seq_lens_k is None:
@@ -179,8 +184,8 @@
p,
softmax_scale,
is_causal,
window_size - 1, # window_size_left
-1, # window_size_right
window_left,
window_right,
None,
rng_state,
)
@@ -203,8 +208,8 @@
softmax_scale,
False, # zero_tensors
is_causal,
window_size - 1, # window_size_left
-1, # window_size_right
window_left,
window_right,
None,
rng_state,
)
@@ -328,17 +333,24 @@ def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
)


def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int:
def _window_size(
attn_bias: Optional[Union[torch.Tensor, AttentionBias]]
) -> Tuple[int, int]:
win_left = -1
win_right = -1
if isinstance(
attn_bias,
(BlockDiagonalCausalLocalAttentionMask,),
(
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
LowerTriangularFromBottomRightLocalAttentionMask,
),
):
return attn_bias._window_size or 0
if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask):
return attn_bias._window_size
if isinstance(attn_bias, LowerTriangularFromBottomRightLocalAttentionMask):
return attn_bias._window_size
return 0
win_left = attn_bias._window_size - 1
if isinstance(attn_bias, LocalAttentionFromBottomRightMask):
win_left = attn_bias.window_left
win_right = attn_bias.window_right
return (win_left, win_right)


def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None:
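With this change `_window_size` returns the `(window_size_left, window_size_right)` pair that is forwarded to the flash kernel, where `-1` is taken to mean "unbounded on that side" (the flash-attention convention, assumed here). A hedged sketch of the resulting mapping; the constructor arguments and the use of the module-private helper below are for illustration only:

```python
from xformers.ops import fmha
from xformers.ops.fmha.flash import _window_size  # module-private, used here only to illustrate

# The new two-sided local mask passes both limits through unchanged.
assert _window_size(
    fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=16, window_right=32)
) == (16, 32)

# Causal local masks store the total window including the query itself, so the
# kernel's left window is `_window_size - 1`; the right side stays unbounded
# because causality is enforced separately through the `is_causal` flag.
assert _window_size(
    fmha.attn_bias.LowerTriangularFromBottomRightLocalAttentionMask(_window_size=64)
) == (63, -1)

# Biases with no locality at all report no window on either side.
assert _window_size(fmha.attn_bias.LowerTriangularMask()) == (-1, -1)
```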
@@ -404,6 +416,7 @@ class FwOp(AttentionFwOpBase):
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
LocalAttentionFromBottomRightMask,
}
SUPPORTS_DROPOUT = True
SUPPORTS_CUSTOM_SCALE = True
@@ -441,6 +454,7 @@ def apply(
seqused_k,
) = _convert_input_format(inp, supports_mqa=True)
if inp.query.numel() > 0 and inp.key.numel() > 0:
win_left, win_right = _window_size(inp.attn_bias)
out, softmax_lse, rng_state = cls.OPERATOR(
inp.query,
inp.key,
Expand All @@ -453,8 +467,9 @@ def apply(
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
return_softmax,
window_left=win_left,
window_right=win_right,
return_softmax=return_softmax,
)
out = out.reshape(out_shape)
else:
@@ -612,6 +627,7 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
if grads.dv.numel() == 0:
grads.dq.zero_()
if grads.dq.numel() and grads.dk.numel():
win_left, win_right = _window_size(inp.attn_bias)
cls.OPERATOR(
grad.reshape(kernel_out_shape).contiguous(),
inp.query,
Expand All @@ -629,8 +645,9 @@ def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
inp.p,
inp.scale_float,
_is_causal(inp.attn_bias),
_window_size(inp.attn_bias),
ctx.rng_state,
window_left=win_left,
window_right=win_right,
rng_state=ctx.rng_state,
)
grads.dq = grads.dq.reshape(dq_shape)
grads.dk = grads.dk.reshape(dk_shape)
