
Commit

refine code
Seventeen17 committed Sep 19, 2024
1 parent 0eae652 commit 2f37c34
Showing 4 changed files with 161 additions and 92 deletions.
4 changes: 2 additions & 2 deletions tests/ops/test_flash_attn.py
@@ -59,8 +59,8 @@ def setup_env():
],
)
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
local, alibi, deterministic, mha_type, dtype):
def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, local,
alibi, deterministic, mha_type, dtype):
if d % 8 != 0:
pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
# TODO(to wenting.swt): fix the correctness issue, refer to FIXME
161 changes: 97 additions & 64 deletions tests/ops/test_flash_attn_varlen.py
@@ -11,17 +11,20 @@
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens = F.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
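
A minimal worked sketch (editor's addition, not from this commit) of what _get_unpad_data produces for a small right-padded mask; the tiny mask is made up for illustration and the values in the comments are computed by hand.

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
# -> tensor([3, 2]): number of real tokens per row
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# -> tensor([0, 1, 2, 4, 5]): flat (b*s) positions of the real tokens
max_seqlen_in_batch = seqlens_in_batch.max().item()  # -> 3
cu_seqlens = F.pad(
    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# -> tensor([0, 3, 5]): cumulative offsets the varlen kernel uses to delimit sequences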


class FlashAttention2(nn.Module):

def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
@@ -32,17 +35,21 @@ def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads

def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):

def _flash_attention_forward(self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None):

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
query_states, key_states, value_states, attention_mask,
query_length)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -58,28 +65,37 @@ def _flash_attention_forward(
causal=True,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) # re fill the masked with 0.f
attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
query_length) # re fill the masked with 0.f
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=True)

return attn_output

def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape # b, s, h, d
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape # b, s, h, d

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k # filter out the key with unmask query
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
head_dim),
indices_k # filter out the key with unmask query
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
@@ -93,28 +109,34 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
query_layer, attention_mask)

return (
query_layer, # (b*s, h, d), b*s is the true data
key_layer, # (b*s, h, d)
value_layer, # (b*s, h, d)
query_layer, # (b*s, h, d), b*s is the true data
key_layer, # (b*s, h, d)
value_layer, # (b*s, h, d)
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
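
A short round-trip sketch (editor's addition, not from this commit) through the flash_attn.bert_padding helpers that _upad_input builds on, assuming the flash_attn version targeted here, where unpad_input returns the four values unpacked above.

import torch
from flash_attn.bert_padding import pad_input, unpad_input

b, s, h, d = 2, 4, 1, 8
x = torch.randn(b, s, h, d)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.int32)

x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, mask)
# x_unpad: (5, h, d) -- only the real tokens survive; cu_seqlens == tensor([0, 3, 5])

x_pad = pad_input(x_unpad, indices, b, s)  # back to (b, s, h, d); padded slots refilled with 0
assert torch.equal(x_pad[0, :3], x[0, :3])
assert (x_pad[1, 2:] == 0).all()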

def forward(self, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _, _ = query_states.size()

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=0.0
)
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0.0)

return attn_output
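
A hypothetical usage sketch (editor's addition, not from this commit) of the FlashAttention2 test module above; the shapes and the right-padded mask are illustrative assumptions, and the flash-attn kernels it calls require a CUDA device with fp16/bf16 inputs.

import torch

bsz, seqlen, nheads, head_dim = 2, 128, 8, 64
attn = FlashAttention2(hidden_size=nheads * head_dim,
                       num_attention_heads=nheads,
                       num_key_value_heads=nheads)
q = torch.randn(bsz, seqlen, nheads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
mask = torch.ones(bsz, seqlen, dtype=torch.int32, device="cuda")
mask[1, seqlen // 2:] = 0                 # right-pad the second sequence (assumed layout)
out = attn(q, k, v, attention_mask=mask)  # -> (bsz, seqlen, nheads, head_dim)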


class FlashAttentionXla(nn.Module):

def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
@@ -125,35 +147,52 @@ def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
self.num_key_value_heads = num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads


def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
def _flash_attention_forward(self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None):

# Contains at least one padding token in the sequence
if attention_mask is None:
attn_output = ta.ops.flash_attn_xla(query_states, key_states, value_states, dropout_p=dropout, causal=True) # re fill the masked with 0.f
attn_output = ta.ops.flash_attn_xla(
query_states,
key_states,
value_states,
dropout_p=dropout,
causal=True) # re fill the masked with 0.f
else:
attn_output = ta.ops.flash_attn_varlen_xla(
query_states, key_states, value_states, attention_mask=attention_mask, dropout_p=dropout, causal=True)
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout_p=dropout,
causal=True)
return attn_output

def forward(self, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None) -> \
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _, _ = query_states.size()

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, q_len, dropout=0.0
)
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=0.0)

return attn_output


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha"])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize( "seqlen", [2048])
@pytest.mark.parametrize("seqlen", [2048])
def test_flash_attn_varlen(seqlen, d, dtype, mha_type):

batch_size = 4
@@ -162,35 +201,14 @@ def test_flash_attn_varlen(seqlen, d, dtype, mha_type):

torch.manual_seed(0)
device = "cuda"
q = torch.randn(
batch_size,
seqlen,
nheads,
d,
device=device,
dtype=dtype)
k = torch.randn(
batch_size,
seqlen,
nheads_k,
d,
device=device,
dtype=dtype)
v = torch.randn(
batch_size,
seqlen,
nheads_k,
d,
device=device,
dtype=dtype)
q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
g = torch.randn_like(q)
attention_mask = torch.zeros(
batch_size,
seqlen,
dtype=torch.int32).to(device)
batch_size, seqlen, dtype=torch.int32).to(device)

k_lengths = torch.randint(low=2, high=seqlen, size=(batch_size,))
print(f'k_lengths={k_lengths}')

for i in range(batch_size):
k_len = k_lengths[i].item()
@@ -240,7 +258,22 @@ def test_flash_attn_varlen(seqlen, d, dtype, mha_type):
dv_xla,
) = torch.autograd.grad(ret_xla, (q_xla, k_xla, v_xla), g_xla)
ta.mark_step()

assert torch.allclose(dq_xla.cpu().detach(), dq.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dk_xla.cpu().detach(), dk.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)
assert torch.allclose(dv_xla.cpu().detach(), dv.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)

assert torch.allclose(
dq_xla.cpu().detach(),
dq.cpu().detach(),
rtol=1e-2,
atol=1e-2,
equal_nan=True)
assert torch.allclose(
dk_xla.cpu().detach(),
dk.cpu().detach(),
rtol=1e-2,
atol=1e-2,
equal_nan=True)
assert torch.allclose(
dv_xla.cpu().detach(),
dv.cpu().detach(),
rtol=1e-2,
atol=1e-2,
equal_nan=True)
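
An editor's illustration (not from this commit) of the 2-D mask format the XLA varlen path consumes: ones mark real tokens, zeros mark padding. The loop that fills the mask is collapsed in this diff view, so the right-padded fill below is an assumption for illustration only.

import torch

batch_size, seqlen = 4, 16
k_lengths = torch.randint(low=2, high=seqlen, size=(batch_size,))
attention_mask = torch.zeros(batch_size, seqlen, dtype=torch.int32)
for i in range(batch_size):
    attention_mask[i, :k_lengths[i].item()] = 1  # assumed right padding
# The mask is then passed directly to ta.ops.flash_attn_varlen_xla(..., attention_mask=attention_mask, ...)
# as in the test above; no explicit unpadding step is needed on the XLA path.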
27 changes: 13 additions & 14 deletions torchacc/ops/flash_attn.py
@@ -165,8 +165,8 @@ def backward(ctx, dout, *args):
class FlashAttnVarlenXla(torch.autograd.Function):

@staticmethod
def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, window_size,
alibi_slopes, deterministic, return_softmax):
def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal,
window_size, alibi_slopes, deterministic, return_softmax):
if softmax_scale is None:
softmax_scale = q.shape[-1]**(-0.5)
assert isinstance(window_size, tuple) and len(window_size) == 2
@@ -175,8 +175,8 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]

softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale, False, causal,
window_size[0], window_size[1], return_softmax, None)
q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale,
False, causal, window_size[0], window_size[1], return_softmax, None)
out = out.to(q.dtype)

ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
@@ -197,9 +197,9 @@ def backward(ctx, dout, *args):
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
ctx.alibi_slopes, ctx.dropout_p,
ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
ctx.window_size[1], ctx.deterministic, None, rng_state)
ctx.alibi_slopes, ctx.dropout_p, ctx.softmax_scale, False,
ctx.causal, ctx.window_size[0], ctx.window_size[1],
ctx.deterministic, None, rng_state)

dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
@@ -220,9 +239,8 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
bsz, q_len, head_size, _ = q.size()

softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
q, k, v, None, alibi_slopes,
dropout_p, softmax_scale, False, causal, window_size[0],
window_size[1], return_softmax, None)
q, k, v, None, alibi_slopes, dropout_p, softmax_scale, False,
causal, window_size[0], window_size[1], return_softmax, None)
out = out.to(q.dtype)

ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
@@ -240,10 +239,10 @@ def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors

dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
dout, q, k, v, out, softmax_lse, None, None,
ctx.alibi_slopes, ctx.dropout_p,
ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
ctx.window_size[1], ctx.deterministic, None, rng_state)
dout, q, k, v, out, softmax_lse, None, None, ctx.alibi_slopes,
ctx.dropout_p, ctx.softmax_scale, False, ctx.causal,
ctx.window_size[0], ctx.window_size[1], ctx.deterministic, None,
rng_state)

dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension
dk = dk[..., :dout.shape[-1]]
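
An editor's sketch (not from this commit) of how an autograd Function like FlashAttnVarlenXla earlier in this file is typically exposed as a helper. The helper name and default values are assumptions modelled on the tests in this commit; only the argument order is taken from FlashAttnVarlenXla.forward.

def flash_attn_varlen_xla(q, k, v, attention_mask, dropout_p=0.0,
                          softmax_scale=None, causal=False,
                          window_size=(-1, -1), alibi_slopes=None,
                          deterministic=False, return_softmax=False):
    # ctx is supplied by autograd; the remaining arguments mirror forward().
    return FlashAttnVarlenXla.apply(q, k, v, attention_mask, dropout_p,
                                    softmax_scale, causal, window_size,
                                    alibi_slopes, deterministic, return_softmax)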