diff --git a/tests/ops/test_flash_attn.py b/tests/ops/test_flash_attn.py
index 68451c9..af8b20e 100644
--- a/tests/ops/test_flash_attn.py
+++ b/tests/ops/test_flash_attn.py
@@ -59,8 +59,8 @@ def setup_env():
     ],
 )
 @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal,
-                           local, alibi, deterministic, mha_type, dtype):
+def test_flash_attn_output(seqlen_q, seqlen_k, d, dropout_p, causal, local,
+                           alibi, deterministic, mha_type, dtype):
     if d % 8 != 0:
         pytest.skip(reason="Expected head_size_og % 8 == 0 to be true")
     # TODO(to wenting.swt): fix the correctness issue, refer to FIXME
diff --git a/tests/ops/test_flash_attn_varlen.py b/tests/ops/test_flash_attn_varlen.py
index 3eea960..8f0cc1d 100644
--- a/tests/ops/test_flash_attn_varlen.py
+++ b/tests/ops/test_flash_attn_varlen.py
@@ -11,17 +11,20 @@
 from flash_attn import flash_attn_func, flash_attn_varlen_func
 from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 
+
 def _get_unpad_data(attention_mask):
     seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
     indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
     max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
     return (
         indices,
         cu_seqlens,
         max_seqlen_in_batch,
     )
 
+
 class FlashAttention2(nn.Module):
 
     def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
@@ -32,17 +35,21 @@ def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-
+    def _flash_attention_forward(self,
+                                 query_states,
+                                 key_states,
+                                 value_states,
+                                 attention_mask,
+                                 query_length,
+                                 dropout=0.0,
+                                 softmax_scale=None):
 
         # Contains at least one padding token in the sequence
         if attention_mask is not None:
             batch_size = query_states.shape[0]
             query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
+                query_states, key_states, value_states, attention_mask,
+                query_length)
 
             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
@@ -58,28 +65,37 @@ def _flash_attention_forward(
                 causal=True,
             )
 
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)  # re fill the masked with 0.f
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size,
+                                    query_length)  # re fill the masked with 0.f
         else:
             attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
-            )
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=True)
 
         return attn_output
 
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape # b, s, h, d
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask,
+                    query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+            attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape  # b, s, h, d
 
         key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k  # filter out the key with unmask query
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+                              head_dim),
+            indices_k  # filter out the key with unmask query
         )
         value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+                                head_dim), indices_k)
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads,
+                                    head_dim), indices_k)
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
             indices_q = indices_k
@@ -93,28 +109,34 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         else:
             # The -q_len: slice assumes left padding.
             attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask)
 
         return (
-            query_layer, # (b*s, h, d), b*s is the true data
-            key_layer, # (b*s, h, d)
-            value_layer, # (b*s, h, d)
+            query_layer,  # (b*s, h, d), b*s is the true data
+            key_layer,  # (b*s, h, d)
+            value_layer,  # (b*s, h, d)
             indices_q,
             (cu_seqlens_q, cu_seqlens_k),
             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
         )
 
-    
+
     def forward(self, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> \
         Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _, _ = query_states.size()
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=0.0
-        )
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=0.0)
 
         return attn_output
 
+
 class FlashAttentionXla(nn.Module):
 
     def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
@@ -125,27 +147,44 @@ def __init__(self, hidden_size, num_attention_heads, num_key_value_heads):
         self.num_key_value_heads = num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
 
-
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
+    def _flash_attention_forward(self,
+                                 query_states,
+                                 key_states,
+                                 value_states,
+                                 attention_mask,
+                                 query_length,
+                                 dropout=0.0,
+                                 softmax_scale=None):
         # Contains at least one padding token in the sequence
         if attention_mask is None:
-            attn_output = ta.ops.flash_attn_xla(query_states, key_states, value_states, dropout_p=dropout, causal=True)  # re fill the masked with 0.f
+            attn_output = ta.ops.flash_attn_xla(
+                query_states,
+                key_states,
+                value_states,
+                dropout_p=dropout,
+                causal=True)  # re fill the masked with 0.f
         else:
             attn_output = ta.ops.flash_attn_varlen_xla(
-                query_states, key_states, value_states, attention_mask=attention_mask, dropout_p=dropout, causal=True)
+                query_states,
+                key_states,
+                value_states,
+                attention_mask=attention_mask,
+                dropout_p=dropout,
+                causal=True)
 
         return attn_output
 
-
     def forward(self, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> \
         Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _, _ = query_states.size()
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=0.0
-        )
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=0.0)
 
         return attn_output
 
@@ -153,7 +192,7 @@ def forward(self, query_states: torch.Tensor, key_states: torch.Tensor, value_st
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("d", [128])
-@pytest.mark.parametrize( "seqlen", [2048])
+@pytest.mark.parametrize("seqlen", [2048])
 def test_flash_attn_varlen(seqlen, d, dtype, mha_type):
     batch_size = 4
 
@@ -162,35 +201,14 @@ def test_flash_attn_varlen(seqlen, d, dtype, mha_type):
     torch.manual_seed(0)
     device = "cuda"
 
-    q = torch.randn(
-        batch_size,
-        seqlen,
-        nheads,
-        d,
-        device=device,
-        dtype=dtype)
-    k = torch.randn(
-        batch_size,
-        seqlen,
-        nheads_k,
-        d,
-        device=device,
-        dtype=dtype)
-    v = torch.randn(
-        batch_size,
-        seqlen,
-        nheads_k,
-        d,
-        device=device,
-        dtype=dtype)
+    q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
+    k = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
+    v = torch.randn(batch_size, seqlen, nheads_k, d, device=device, dtype=dtype)
     g = torch.randn_like(q)
 
     attention_mask = torch.zeros(
-        batch_size,
-        seqlen,
-        dtype=torch.int32).to(device)
+        batch_size, seqlen, dtype=torch.int32).to(device)
     k_lengths = torch.randint(low=2, high=seqlen, size=(batch_size,))
-    print(f'k_lengths={k_lengths}')
 
     for i in range(batch_size):
         k_len = k_lengths[i].item()
@@ -240,7 +258,22 @@ def test_flash_attn_varlen(seqlen, d, dtype, mha_type):
         dv_xla,
     ) = torch.autograd.grad(ret_xla, (q_xla, k_xla, v_xla), g_xla)
     ta.mark_step()
-
-    assert torch.allclose(dq_xla.cpu().detach(), dq.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)
-    assert torch.allclose(dk_xla.cpu().detach(), dk.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)
-    assert torch.allclose(dv_xla.cpu().detach(), dv.cpu().detach(), rtol=1e-2, atol=1e-2, equal_nan=True)
\ No newline at end of file
+
+    assert torch.allclose(
+        dq_xla.cpu().detach(),
+        dq.cpu().detach(),
+        rtol=1e-2,
+        atol=1e-2,
+        equal_nan=True)
+    assert torch.allclose(
+        dk_xla.cpu().detach(),
+        dk.cpu().detach(),
+        rtol=1e-2,
+        atol=1e-2,
+        equal_nan=True)
+    assert torch.allclose(
+        dv_xla.cpu().detach(),
+        dv.cpu().detach(),
+        rtol=1e-2,
+        atol=1e-2,
+        equal_nan=True)
diff --git a/torchacc/ops/flash_attn.py b/torchacc/ops/flash_attn.py
index 6524260..8b3e481 100644
--- a/torchacc/ops/flash_attn.py
+++ b/torchacc/ops/flash_attn.py
@@ -165,8 +165,8 @@ def backward(ctx, dout, *args):
 class FlashAttnVarlenXla(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, window_size,
-                alibi_slopes, deterministic, return_softmax):
+    def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal,
+                window_size, alibi_slopes, deterministic, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1]**(-0.5)
         assert isinstance(window_size, tuple) and len(window_size) == 2
@@ -175,8 +175,8 @@ def forward(ctx, q, k, v, attention_mask, dropout_p, softmax_scale, causal, wind
 
         q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
         softmax_lse, out, rng_state, cu_seqlens_q, cu_seqlens_k = torch_xla._XLAC._flash_attention_forward(
-            q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale, False, causal,
-            window_size[0], window_size[1], return_softmax, None)
+            q, k, v, attention_mask, alibi_slopes, dropout_p, softmax_scale,
+            False, causal, window_size[0], window_size[1], return_softmax, None)
         out = out.to(q.dtype)
 
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q,
@@ -197,9 +197,9 @@ def backward(ctx, dout, *args):
         dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
             dout, q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k,
-            ctx.alibi_slopes, ctx.dropout_p,
-            ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
-            ctx.window_size[1], ctx.deterministic, None, rng_state)
+            ctx.alibi_slopes, ctx.dropout_p, ctx.softmax_scale, False,
+            ctx.causal, ctx.window_size[0], ctx.window_size[1],
+            ctx.deterministic, None, rng_state)
 
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
@@ -220,9 +220,8 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size,
         bsz, q_len, head_size, _ = q.size()
 
         softmax_lse, out, rng_state = torch_xla._XLAC._flash_attention_forward(
-            q, k, v, None, alibi_slopes,
-            dropout_p, softmax_scale, False, causal, window_size[0],
-            window_size[1], return_softmax, None)
+            q, k, v, None, alibi_slopes, dropout_p, softmax_scale, False,
+            causal, window_size[0], window_size[1], return_softmax, None)
         out = out.to(q.dtype)
 
         ctx.save_for_backward(q, k, v, out, softmax_lse, rng_state)
@@ -240,10 +239,10 @@ def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
 
         dq, dk, dv, softmax_d = torch_xla._XLAC._flash_attention_backward(
-            dout, q, k, v, out, softmax_lse, None, None,
-            ctx.alibi_slopes, ctx.dropout_p,
-            ctx.softmax_scale, False, ctx.causal, ctx.window_size[0],
-            ctx.window_size[1], ctx.deterministic, None, rng_state)
+            dout, q, k, v, out, softmax_lse, None, None, ctx.alibi_slopes,
+            ctx.dropout_p, ctx.softmax_scale, False, ctx.causal,
+            ctx.window_size[0], ctx.window_size[1], ctx.deterministic, None,
+            rng_state)
 
         dq = dq[..., :dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., :dout.shape[-1]]
diff --git a/torchacc/utils/patch.py b/torchacc/utils/patch.py
index cd8a7ce..6ee3451 100644
--- a/torchacc/utils/patch.py
+++ b/torchacc/utils/patch.py
@@ -6,6 +6,7 @@
 
 import torchacc.ops as ops
 from torchacc.core import amp
+from torchacc.utils.logger import logger
 
 
 def _patch_functions(fn, newfn):
@@ -55,20 +56,56 @@ def patch_amp():
 
 def patch_fa():
     '''
-    Replace `flash_attn.flash_attn_func`, `flash_attn.flash_attn_varlen_func` with
-    `torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`,
-    and dynamically determine which one to use at runtime based on the input device.
+    Replace `transformers.modeling_flash_attention_utils._flash_attention_forward` with
+    `torchacc.ops.flash_attn_xla` and `torchacc.ops.flash_attn_varlen_xla`
     '''
-    from transformers.models.llama.modeling_llama import LlamaFlashAttention2
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        if attention_mask is not None:
-            return ops.flash_attn_varlen_xla(query_states.contiguous(), key_states.contiguous(), value_states.contiguous(), attention_mask=attention_mask.contiguous(), dropout_p=dropout, softmax_scale=softmax_scale)
+    try:
+        import transformers
+        version = transformers.__version__
+        if version >= "4.44.2":
+            import transformers.modeling_flash_attention_utils as modeling_flash_attention_utils
+            from typing import Optional
+
+            def _flash_attention_forward(
+                query_states: torch.Tensor,
+                key_states: torch.Tensor,
+                value_states: torch.Tensor,
+                attention_mask: torch.Tensor,
+                query_length: int,
+                is_causal: bool,
+                dropout: float = 0.0,
+                position_ids: Optional[torch.Tensor] = None,
+                softmax_scale: Optional[float] = None,
+                sliding_window: Optional[int] = None,
+                use_top_left_mask: bool = False,
+                softcap: Optional[float] = None,
+                deterministic: bool = None,
+            ):
+                if attention_mask is not None:
+                    return ops.flash_attn_varlen_xla(
+                        query_states.contiguous(),
+                        key_states.contiguous(),
+                        value_states.contiguous(),
+                        attention_mask=attention_mask.contiguous(),
+                        dropout_p=dropout,
+                        softmax_scale=softmax_scale)
+                else:
+                    return ops.flash_attn_xla(
+                        query_states,
+                        key_states,
+                        value_states,
+                        dropout_p=dropout,
+                        softmax_scale=softmax_scale)
+
+            modeling_flash_attention_utils._flash_attention_forward = _flash_attention_forward
         else:
-            return ops.flash_attn_xla(query_states, key_states, value_states, dropout_p=dropout, softmax_scale=softmax_scale)
-
-    LlamaFlashAttention2._flash_attention_forward = _flash_attention_forward
+            logger.warn(
+                f'FlashAttention is not successfully patched with transformers version={version}'
+            )
+    except:
+        logger.warn(
+            'transformers is not installed, torchacc will not patch FlashAttention for transformers'
+        )
 
 
 def patch_llama(use_flash_attn):
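
Reviewer note, not part of the patch: a minimal sketch of how the two paths selected by the patched _flash_attention_forward are expected to be driven, mirroring tests/ops/test_flash_attn_varlen.py. The head count and the ta.lazy_device() helper are illustrative assumptions; only ta.ops.flash_attn_xla, ta.ops.flash_attn_varlen_xla and ta.mark_step() come from the code in this diff.

# Illustrative sketch only; it is not part of the patch above.
import torch
import torchacc as ta

batch_size, seqlen, d = 4, 2048, 128  # as in the varlen test
nheads = 8                            # assumption: not visible in the diff
dtype = torch.bfloat16
device = ta.lazy_device()             # assumption: handle to the XLA/lazy device

q = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, d, device=device, dtype=dtype)

# Dense path: no padding, so no attention mask is passed (patch_fa routes
# transformers here when attention_mask is None).
out_dense = ta.ops.flash_attn_xla(q, k, v, dropout_p=0.0, causal=True)

# Padded path: a 0/1 mask marks the valid keys per sequence (patch_fa routes
# transformers here when an attention_mask is present).
attention_mask = torch.ones(batch_size, seqlen, dtype=torch.int32, device=device)
attention_mask[0, seqlen // 2:] = 0   # pretend half of sequence 0 is padding
out_varlen = ta.ops.flash_attn_varlen_xla(
    q.contiguous(),
    k.contiguous(),
    v.contiguous(),
    attention_mask=attention_mask.contiguous(),
    dropout_p=0.0,
    causal=True)

ta.mark_step()  # flush the lazy graph, as the tests do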