[Mistral] Add Flash Attention-2 support for mistral #26464

Merged
merged 28 commits into from Oct 3, 2023

Commits (28)
5d9bc48
add FA-2 support for mistral
younesbelkada Sep 28, 2023
0983d88
fixup
younesbelkada Sep 28, 2023
b8c9198
add sliding windows
younesbelkada Sep 28, 2023
d740849
Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t…
younesbelkada Sep 28, 2023
bd58ca7
fixing few nits
younesbelkada Sep 28, 2023
43b0289
v1 slicing cache - logits do not match
younesbelkada Sep 29, 2023
ed2616f
add comment
younesbelkada Sep 29, 2023
7cafc2d
fix bugs
younesbelkada Oct 2, 2023
2b8c7b4
more mem efficient
younesbelkada Oct 2, 2023
4a3387d
add warning once
younesbelkada Oct 2, 2023
885b601
add warning once
younesbelkada Oct 2, 2023
172d99a
oops
younesbelkada Oct 2, 2023
253b383
fixup
younesbelkada Oct 2, 2023
e4d0fb7
more comments
younesbelkada Oct 2, 2023
a245722
copy
younesbelkada Oct 2, 2023
3079896
Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t…
younesbelkada Oct 2, 2023
e71c50d
add safety checker
younesbelkada Oct 2, 2023
a21d903
Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t…
younesbelkada Oct 2, 2023
5d1f589
fixup
younesbelkada Oct 2, 2023
b478e04
Update src/transformers/models/mistral/modeling_mistral.py
younesbelkada Oct 3, 2023
2fe2f49
copied from
younesbelkada Oct 3, 2023
25789d1
up
younesbelkada Oct 3, 2023
05ec7f4
raise when padding side is right
younesbelkada Oct 3, 2023
5a79195
Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t…
younesbelkada Oct 3, 2023
f9a69bc
fixup
younesbelkada Oct 3, 2023
6a48dd3
add doc + few minor changes
younesbelkada Oct 3, 2023
76763c7
Merge branch 'add-mistral-fa-2' of https://github.com/younesbelkada/t…
younesbelkada Oct 3, 2023
c286946
fixup
younesbelkada Oct 3, 2023
45 changes: 45 additions & 0 deletions docs/source/en/model_doc/mistral.md
@@ -82,6 +82,51 @@ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
model = MistralForCausalLM.from_pretrained("/output/path")
```

## Combining Mistral and Flash Attention 2

First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

```bash
pip install -U flash-attn --no-build-isolation
```

Also make sure that your hardware is compatible with Flash Attention 2. Read more about it in the official documentation of the [`flash-attn`](https://github.com/Dao-AILab/flash-attention) repository. Also make sure to load your model in half-precision (e.g. `torch.float16`).
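
As a quick sanity check, the snippet below (an illustrative sketch, assuming an NVIDIA GPU; the exact compute-capability requirements of `flash-attn` may evolve) verifies that the device is Ampere or newer:

```python
>>> import torch

>>> # Flash Attention 2 kernels currently target Ampere or newer GPUs (compute capability >= 8.0)
>>> major, minor = torch.cuda.get_device_capability()
>>> if major < 8:
...     print("This GPU is unlikely to support Flash Attention 2, see the flash-attn documentation.")
```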

To load and run a model using Flash Attention 2, refer to the snippet below:

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"

>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
>>> model.to(device)

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=100, do_sample=True)
>>> tokenizer.batch_decode(generated_ids)[0]
"The expected outupt"
```

### Expected speedups

Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using the `mistralai/Mistral-7B-v0.1` checkpoint and the Flash Attention 2 version of the model.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/mistral-7b-inference-large-seqlen.png">
</div>

### Sliding window Attention

The current implementation supports the sliding window attention mechanism and memory-efficient cache management.
To enable sliding window attention, make sure to use a `flash-attn` version that supports it (`>=2.3.0`).
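
A minimal way to check the installed version (assuming `flash-attn` was installed as a regular Python package) is:

```python
>>> from importlib.metadata import version

>>> print(version("flash-attn"))  # sliding window attention requires flash-attn >= 2.3.0
```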

The Flash Attention-2 model also uses a more memory-efficient cache slicing mechanism. Following the rolling cache mechanism recommended by the official implementation of the Mistral model, we keep the cache size fixed (`self.config.sliding_window`), support batched generation only for `padding_side="left"`, and use the absolute position of the current token to compute the positional embedding.
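
For illustration, here is a sketch of batched generation with left padding (the prompts and generation settings are arbitrary, and reusing the EOS token as the pad token is an assumption since the tokenizer does not define one):

```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", padding_side="left")
>>> tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token

>>> model = AutoModelForCausalLM.from_pretrained(
...     "mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True
... ).to("cuda")

>>> prompts = ["My favourite condiment is", "The capital of France is"]
>>> model_inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")

>>> generated_ids = model.generate(**model_inputs, max_new_tokens=20, do_sample=True)
>>> tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```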

## The Mistral Team

Albert Jiang, Alexandre Sablayrolles, Arthur Mensch, Chris Bamford, Devendra Singh Chaplot, Diego de las Casas, Florian Bressand, Gianna Lengyel, Guillaume Lample, Lélio Renard Lavaud, Lucile Saulnier, Marie-Anne Lachaux, Pierre Stock, Teven Le Scao, Thibaut Lavril, Thomas Wang, Timothée Lacroix, William El Sayed.
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -32,6 +32,7 @@ Make sure to follow the installation guide on the repository mentioned above to
We natively support Flash Attention 2 for the following models:

- Llama
- Mistral
- Falcon

You can request Flash Attention 2 support for more models by opening an issue on GitHub, or even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported by the `BetterTransformer` API below.*
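
As an illustrative sketch of training with padding tokens (the checkpoint, batch, and settings below are placeholders rather than recommendations):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "mistralai/Mistral-7B-v0.1"  # placeholder: any model with Flash Attention 2 support

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token if none is defined

model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, use_flash_attention_2=True
).to("cuda")

# A padded batch: shorter sequences are padded and the padded positions are
# excluded from the loss by setting their labels to -100.
batch = tokenizer(
    ["A short example.", "A slightly longer training example than the first one."],
    return_tensors="pt",
    padding=True,
).to("cuda")
labels = batch["input_ids"].clone()
labels[batch["attention_mask"] == 0] = -100

outputs = model(**batch, labels=labels)
outputs.loss.backward()
```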