Add varying sequence length support to Split KV MHA (fairinternal/xformers#736)

* [WIP] Add varying sequence length support to split-k attention

* Add more tests and benchmarks

* Remove max_seq_len hint

Might not be CUDAGraphs-friendly and only really helps for very small contexts

* nit

* Add bound-check to query and output

Was missing before and would get OOB errors in prod

* Fix lint

* Reviewer comments

* Fix lint

__original_commit__ = fairinternal/xformers@f47fa92
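
For context, the capability this commit adds can be exercised roughly as in the sketch below: each sequence in the batch gets its own number of valid keys inside a fixed padded KV layout, and the decoder op is selected explicitly. The bias class and its from_seqlens keyword arguments are taken from the surrounding test code rather than from the hunks shown below, so treat them as assumptions.

    import torch
    from xformers.ops import fmha

    bsz, n_heads, d, padding = 4, 8, 128, 64

    # Padded KV layout: every sequence owns a `padding`-sized slot in the packed dim.
    k = torch.randn(1, bsz * padding, n_heads, d, dtype=torch.float16, device="cuda")
    v = torch.randn(1, bsz * padding, n_heads, d, dtype=torch.float16, device="cuda")
    q = torch.randn(1, bsz, n_heads, d, dtype=torch.float16, device="cuda")

    # Varying number of valid keys per sequence -- the case this commit enables.
    k_seqlen = torch.randint(1, padding + 1, (bsz,)).tolist()

    # Assumed bias construction (mirrors the test; the test additionally passes causal_diagonal).
    attn_bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
        q_seqlen=[1] * bsz,
        kv_padding=padding,
        kv_seqlen=k_seqlen,
    )

    out = fmha.memory_efficient_attention_forward(q, k, v, attn_bias, op=fmha.decoder.FwOp)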
fmassa authored and xFormers Bot committed Jul 31, 2023
1 parent f525106 commit 82254f4
Showing 2 changed files with 24 additions and 11 deletions.
23 changes: 14 additions & 9 deletions tests/test_mem_eff_attention.py
@@ -1630,18 +1630,22 @@ def test_attn_bias_padded() -> None:
 
 
 @sm80_or_better_only
+@pytest.mark.parametrize("op", [fmha.decoder.FwOp])
 @pytest.mark.parametrize("multiquery", [True, False], ids=lambda x: "mq" if x else "")
-@pytest.mark.parametrize("n_heads", [1, 32])
+@pytest.mark.parametrize("n_heads", [1, 16, 32])
+@pytest.mark.parametrize("padding", [32, 4096])
 @pytest.mark.parametrize("bsz", [1, 8])
 @pytest.mark.parametrize("dtype", ["f16", "bf16", "f32"])
-def test_decoder(multiquery: bool, n_heads: int, bsz: int, dtype: str) -> None:
+def test_decoder(
+    op, multiquery: bool, n_heads: int, padding: int, bsz: int, dtype: str
+) -> None:
     dtype_ = {"f16": torch.float16, "bf16": torch.bfloat16, "f32": torch.float32}[dtype]
     torch.manual_seed(1)
-    d, padding = 128, 32
+    d = 128
     k_shape = (1, bsz * padding, n_heads, d)
     # TODO: support 2 kv heads etc.
     k = torch.randn(k_shape, dtype=dtype_).cuda()
-    k_seqlen = [5, 8, 7, 1, 9, 3, 12, 32][:bsz]
+    k_seqlen = torch.randint(1, padding + 1, (bsz,)).tolist()
     v = torch.randn(k_shape, dtype=dtype_).cuda()
     q = torch.randn((1, bsz, n_heads, d), dtype=dtype_).cuda()
     causal_diagonal = torch.tensor(  # TODO: make unnecessary
@@ -1658,18 +1662,19 @@ def test_decoder(multiquery: bool, n_heads: int, bsz: int, dtype: str) -> None:
         causal_diagonal=causal_diagonal,
         kv_padding=padding,
     )
+    inp = fmha.Inputs(q, k, v, attn_bias=attn_bias)
+    if not op.supports(inp):
+        pytest.skip("not supported")
 
     cutlass_output = fmha.memory_efficient_attention_forward(
         q, k, v, attn_bias, op=fmha.cutlass.FwOp
     )
-    decoder_output = fmha.memory_efficient_attention_forward(
-        q, k, v, attn_bias, op=fmha.decoder.FwOp
-    )
+    decoder_output = fmha.memory_efficient_attention_forward(q, k, v, attn_bias, op=op)
     assert_allclose(
         decoder_output,
         cutlass_output,
-        atol=fmha.cutlass.FwOp.ERROR_ATOL[torch.float16] * 4,
-        rtol=fmha.cutlass.FwOp.ERROR_RTOL[torch.float16],
+        atol=fmha.cutlass.FwOp.ERROR_ATOL[dtype_],
+        rtol=fmha.cutlass.FwOp.ERROR_RTOL[dtype_],
     )


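The inp = fmha.Inputs(...) / op.supports(inp) guard added above can also be used outside of pytest to pick an op at runtime; a minimal sketch, assuming only the supports() API that the diff itself uses:

    from xformers.ops import fmha

    def pick_fw_op(q, k, v, attn_bias):
        # Ask each candidate op whether it supports these exact inputs,
        # the same check the test and benchmark now perform before dispatching.
        inp = fmha.Inputs(q, k, v, attn_bias=attn_bias)
        for op in (fmha.decoder.FwOp, fmha.cutlass.FwOp):
            if op.supports(inp):
                return op
        raise NotImplementedError("no forward op supports these inputs")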
12 changes: 10 additions & 2 deletions xformers/benchmarks/benchmark_mem_eff_attn_decoder.py
@@ -69,9 +69,12 @@ def T(t):
     (32, 1024, 500),
     (1000, 1024, 2),
     (8000, 8192, 1),
+    (240, 256, 32),
+    (2048, 2 * 1024, 4),
+    (4096 * 2, 8 * 1024, 1),
 ]
 
-N_HEADS = [8, 64]
+N_HEADS = [8, 16, 64]
 
 
 def product_dict(**kwargs):
@@ -95,7 +98,8 @@ def mem_eff_attention_decoder(
     kv_shape, n_heads: int, num_threads: int, multiquery: bool
 ):
     n_keys, padding, B = kv_shape
-    k_seqlen = [n_keys] * B
+    torch.manual_seed(42)
+    k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist()
     K = 128
 
     q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16)
@@ -122,6 +126,10 @@ def mem_eff_attention_decoder(
 
     has_run = False
     for fw_op in OPS:
+        inp = fmha.Inputs(q, k, v, attn_bias=bias)
+        if not fw_op.supports(inp):
+            continue
+
         fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op)
 
         yield benchmark.Timer(
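The new benchmark cases are (n_keys, padding, B) triples; with K = 128 and bfloat16 K/V tensors of shape (1, B * padding, n_heads, K), their memory footprint can be estimated as below. This is only a back-of-the-envelope sanity check, not part of the benchmark script.

    K = 128
    N_HEADS = [8, 16, 64]
    NEW_CASES = [(240, 256, 32), (2048, 2 * 1024, 4), (4096 * 2, 8 * 1024, 1)]

    for n_keys, padding, B in NEW_CASES:
        for n_heads in N_HEADS:
            # K and V each hold B * padding * n_heads * K bfloat16 (2-byte) elements.
            kv_bytes = 2 * B * padding * n_heads * K * 2
            print(f"n_keys={n_keys} padding={padding} B={B} heads={n_heads}: "
                  f"~{kv_bytes / 2**20:.0f} MiB of K+V")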
