Refactor flash attention implementation in transformers #31446

Merged: 62 commits, Jul 11, 2024

Changes from 1 commit

Commits (62)
b66fdb0
dumb commit
ArthurZucker May 31, 2024
029ee11
nit
ArthurZucker Jun 17, 2024
9a7885d
Merge branch 'main' into backend-compatible
ArthurZucker Jun 17, 2024
ac3e5b5
update
ArthurZucker Jun 17, 2024
a7c48bd
something like this
ArthurZucker Jun 17, 2024
682f221
unpack in modeling utils
ArthurZucker Jun 17, 2024
2201178
safe import
ArthurZucker Jun 17, 2024
55a3503
oups
ArthurZucker Jun 21, 2024
b5cbaef
update
ArthurZucker Jun 26, 2024
7c6fdd7
nits
ArthurZucker Jun 26, 2024
08d7e1e
diff convert gemma
ArthurZucker Jun 26, 2024
27044da
update
ArthurZucker Jun 26, 2024
ca316a0
start propagating
ArthurZucker Jun 26, 2024
ea93267
udpate other modeling code as well
ArthurZucker Jun 26, 2024
4b67223
update for sliding window models
ArthurZucker Jun 26, 2024
d59ac0c
nits
ArthurZucker Jun 26, 2024
a1d3866
more init cleanups
ArthurZucker Jun 26, 2024
aea7f03
styling
ArthurZucker Jun 26, 2024
f1bedd0
fixup
ArthurZucker Jun 26, 2024
86e2edc
noice
ArthurZucker Jun 26, 2024
e90a944
pass fixup
ArthurZucker Jun 26, 2024
093fbf5
typo typing_extension -> typing_extensions
ArthurZucker Jun 26, 2024
a1c56d2
torch.nn.functionnal -> torch.nn.functional
ArthurZucker Jun 26, 2024
1aad4a2
add to import structure
ArthurZucker Jun 26, 2024
10bc1fa
unpack
ArthurZucker Jun 26, 2024
9f08ddb
simplify a bit more for this first version
ArthurZucker Jun 26, 2024
2e65e57
nut
ArthurZucker Jun 26, 2024
f8622e6
update
ArthurZucker Jun 26, 2024
2bb4347
update
ArthurZucker Jun 26, 2024
9be7579
nit
ArthurZucker Jun 26, 2024
889cbf8
ease the import of `Unpack`
ArthurZucker Jun 26, 2024
070af2d
remove useless `use_sliding_window`
ArthurZucker Jun 26, 2024
80057a0
no qua please
ArthurZucker Jun 26, 2024
c0b024d
protect import?
ArthurZucker Jun 26, 2024
8f7d1c1
style
ArthurZucker Jun 26, 2024
46b77f9
[run-slow]
ArthurZucker Jun 26, 2024
4a98ee7
[run slow] llama,gemma,mistral,mixtral
ArthurZucker Jun 26, 2024
25b2c10
remove extra kwargs
ArthurZucker Jun 26, 2024
8c3780d
Merge branch 'main' of github.com:huggingface/transformers into backe…
ArthurZucker Jun 26, 2024
1d38dab
fix llama
ArthurZucker Jun 26, 2024
f64864a
address review comments
fxmarty Jul 1, 2024
565c5dc
apply diff_model_converter to modeling_gemma.py
fxmarty Jul 1, 2024
2403ce5
Merge branch 'main' into backend-compatible
fxmarty Jul 1, 2024
c89571d
remove cache_position 1
fxmarty Jul 2, 2024
32c2df8
remove cache_position 2
fxmarty Jul 2, 2024
54a9fb0
some cleaning
fxmarty Jul 2, 2024
206731e
refactor gemma2 as well
fxmarty Jul 2, 2024
7c65fc7
Merge branch 'main' into backend-compatible
fxmarty Jul 2, 2024
1be8c31
apply review comments
fxmarty Jul 3, 2024
8d181ea
rename file to modeling_flash_attention_utils.py
fxmarty Jul 3, 2024
3a30cb6
Merge branch 'main' into backend-compatible
fxmarty Jul 8, 2024
c92028a
siglip refactor
fxmarty Jul 8, 2024
7243993
remove dead code
fxmarty Jul 8, 2024
8b077d8
is the hub down?
fxmarty Jul 8, 2024
a9796bc
still down?
fxmarty Jul 9, 2024
6752a9c
fix siglip
fxmarty Jul 10, 2024
3a9cf1b
Merge branch 'main' into backend-compatible
fxmarty Jul 11, 2024
b4d1df5
fix gemma2
fxmarty Jul 11, 2024
1e1bc2f
fatal: Could not read from remote repository.
fxmarty Jul 11, 2024
c79ca83
fix typo in softcap implem
fxmarty Jul 11, 2024
30dc123
flacky
fxmarty Jul 11, 2024
fae6843
Failed: Timeout >120.0s
fxmarty Jul 11, 2024
Merge branch 'main' into backend-compatible
fxmarty committed Jul 11, 2024
commit 3a9cf1b279a436c96a8ea7811dd76ff3b58989cb
7 changes: 4 additions & 3 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -19,9 +19,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -40,7 +42,6 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -359,8 +360,7 @@ def forward(
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)

########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING
attn_output = _flash_attention_forward(
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
@@ -370,6 +370,7 @@
softmax_scale=self.scaling,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping,
)

attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
You are viewing a condensed version of this merge commit.
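For orientation, the refactor moves each model's private `_flash_attention_forward` method into a shared helper in `src/transformers/modeling_flash_attention_utils.py`, which attention classes call with their own scaling, sliding-window, and soft-capping settings (as in the Gemma2 diff above). Below is a minimal sketch of such a call. The keyword names mirror those in the diff, but treat the exact signature as approximate; the shapes, scaling, and softcap values are illustrative, and the snippet assumes flash-attn 2.x is installed and a CUDA device is available.

```python
# Illustrative sketch only: calling the shared flash-attention helper factored out
# by this PR into modeling_flash_attention_utils.py. Assumes flash-attn 2.x and a
# CUDA device; shapes and values below are made up for the example.
import torch

from transformers.modeling_flash_attention_utils import _flash_attention_forward

bsz, q_len, num_heads, head_dim = 1, 16, 8, 64
device, dtype = "cuda", torch.float16

# Flash attention expects (batch, seq_len, num_heads, head_dim) tensors.
query_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
key_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)
value_states = torch.randn(bsz, q_len, num_heads, head_dim, device=device, dtype=dtype)

attn_output = _flash_attention_forward(
    query_states,
    key_states,
    value_states,
    None,                          # attention_mask: None here, i.e. no padding
    q_len,                         # query_length
    is_causal=True,
    dropout=0.0,
    softmax_scale=head_dim**-0.5,  # Gemma2 passes self.scaling here
    sliding_window=None,           # set to an int for sliding-window models
    use_top_left_mask=False,
    softcap=50.0,                  # Gemma2 passes config.attn_logit_softcapping;
                                   # soft-capping applies softcap * tanh(scores / softcap)
)

# The helper returns (batch, seq_len, num_heads, head_dim); callers reshape it
# back to the hidden size, as the diff above does with attn_output.reshape(...).
print(attn_output.shape)
```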