Refactor flash attention implementation in transformers (#31446)
* dumb commit
* nit
* update
* something like this
* unpack in modeling utils
* safe import
* oups
* update
* nits
* diff convert gemma
* update
* start propagating
* udpate other modeling code as well
* update for sliding window models
* nits
* more init cleanups
* styling
* fixup
* noice
* pass fixup
* typo typing_extension -> typing_extensions
* torch.nn.functionnal -> torch.nn.functional
* add to import structure
* unpack
* simplify a bit more for this first version
* nut
* update
* update
* nit
* ease the import of `Unpack`
* remove useless `use_sliding_window`
* no qua please
* protect import?
* style
* [run-slow]
* [run slow] llama,gemma,mistral,mixtral
* remove extra kwargs
* fix llama
* address review comments
* apply diff_model_converter to modeling_gemma.py
* remove cache_position 1
* remove cache_position 2
* some cleaning
* refactor gemma2 as well
* apply review comments
* rename file to modeling_flash_attention_utils.py
* siglip refactor
* remove dead code
* is the hub down?
* still down?
* fix siglip
* fix gemma2
* fatal: Could not read from remote repository.
* fix typo in softcap implem
* flacky
* Failed: Timeout >120.0s

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
1 parent ad4ef3a · commit e314395
Showing 49 changed files with 792 additions and 5,365 deletions.
@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available


if is_flash_attn_2_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """
    Unpads the query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.
    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for the query, key, and value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )


def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
    the input is first unpadded, then the attention scores are computed, and the final attention output is padded back.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to the Flash Attention API.
        key_states (`torch.Tensor`):
            Input key states to be passed to the Flash Attention API.
        value_states (`torch.Tensor`):
            Input value states to be passed to the Flash Attention API.
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        dropout (`float`):
            Attention dropout.
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
        use_top_left_mask (`bool`, defaults to `False`):
            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle the difference.
        softcap (`float`, *optional*):
            Softcap for the attention logits, used e.g. in gemma2.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

    return attn_output
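
To make the unpadding bookkeeping concrete, here is a small worked check (illustrative only, not part of the commit) of what `_get_unpad_data` computes for a left-padded batch of two sequences with 2 and 3 valid tokens; the toy mask below is an assumption chosen for the example.

import torch
import torch.nn.functional as F

# Toy padding mask: batch of 2, max length 5, left-padded (1 = valid token).
attention_mask = torch.tensor([[0, 0, 0, 1, 1],
                               [0, 0, 1, 1, 1]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)                       # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()            # tensor([3, 4, 7, 8, 9])
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))   # tensor([0, 2, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                                    # 3

`flash_attn_varlen_func` then treats the 5 valid tokens as a single packed sequence, with `cu_seqlens` marking where each batch element starts and ends, and `pad_input` later uses `indices` to scatter the attention output back to the padded (batch_size, seq_len) layout.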
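And a minimal sketch of how an attention module might call the new helper, assuming flash-attn 2 is installed and the tensors live on a CUDA device in half precision; the shapes and argument values are illustrative, not taken from this commit, and the import path simply follows the file name given in the commit message (modeling_flash_attention_utils.py).

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
value = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

# attention_mask=None takes the dense flash_attn_func path; passing a
# (batch_size, seq_len) padding mask routes through _upad_input and
# flash_attn_varlen_func instead.
attn_output = _flash_attention_forward(
    query,
    key,
    value,
    attention_mask=None,
    query_length=seq_len,
    is_causal=True,
    dropout=0.0,
    sliding_window=None,  # set an int to request windowed attention when flash-attn supports it
    softcap=None,         # e.g. a Gemma 2 style model would pass its attention logit softcap here
)
print(attn_output.shape)  # torch.Size([2, 16, 8, 64])

Note that `window_size` is only forwarded when the installed flash-attn build exposes it (the `_flash_supports_window_size` check), while `softcap` is passed through whenever it is set, so it requires a flash-attn version that accepts that kwarg.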