Commit 102dc41

Rename flash-attn to flash-attn2 (#4514)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent: 5de62b0

3 files changed: +8 −8 lines

docs/source/kernels_hub.md

Lines changed: 4 additions & 4 deletions

````diff
@@ -27,20 +27,20 @@ from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn"  # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
+    attn_implementation="kernels-community/flash-attn2"  # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
 )
 ```
 
 Or when running a TRL training script:
 
 ```bash
-python sft.py ... --attn_implementation kernels-community/flash-attn
+python sft.py ... --attn_implementation kernels-community/flash-attn2
 ```
 
 Or using the TRL CLI:
 
 ```bash
-trl sft ... --attn_implementation kernels-community/flash-attn
+trl sft ... --attn_implementation kernels-community/flash-attn2
 ```
 
 > [!TIP]
@@ -84,7 +84,7 @@ from trl import SFTConfig
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn"  # choose the desired FlashAttention variant
+    attn_implementation="kernels-community/flash-attn2"  # choose the desired FlashAttention variant
 )
 
 training_args = SFTConfig(
````
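The documentation change only swaps the kernel name passed through `attn_implementation`. As an illustration that is not part of this commit, the sketch below combines the renamed kernel with TRL's `SFTTrainer`; the model and dataset names are placeholders, not values from the diff.

```python
# Illustrative sketch, not from this commit: load a model with the renamed
# Hub kernel and fine-tune it with TRL. Model/dataset names are placeholders.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder model
    attn_implementation="kernels-community/flash-attn2",  # kernel fetched from the Hub
)

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(output_dir="sft-output"),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),  # placeholder dataset
)
trainer.train()
```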

trl/trainer/model_config.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -43,8 +43,8 @@ class ModelConfig:
             be set to `True` for repositories you trust and in which you have read the code, as it will execute code
             present on the Hub on your local machine.
         attn_implementation (`str`, *optional*):
-            Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
-            you must install this manually by running `pip install flash-attn --no-build-isolation`.
+            Which attention implementation to use. More information in the [Kernels Hub Integrations
+            Guide](kernels_hub).
         use_peft (`bool`, *optional*, defaults to `False`):
             Whether to use PEFT for training.
         lora_r (`int`, *optional*, defaults to `16`):
```
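The docstring now points users at the Kernels Hub guide instead of a manual `pip install flash-attn`. For context only, here is a minimal sketch of how `ModelConfig.attn_implementation` typically reaches the model loader, assuming TRL's `TrlParser` and the fields already defined on `ModelConfig`; none of this code is introduced by the commit.

```python
# Sketch only, assuming TRL's TrlParser and existing ModelConfig fields;
# not code added by this commit.
from transformers import AutoModelForCausalLM
from trl import ModelConfig, TrlParser

parser = TrlParser(ModelConfig)
(model_args,) = parser.parse_args_and_config()

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    attn_implementation=model_args.attn_implementation,  # e.g. "kernels-community/flash-attn2"
    trust_remote_code=model_args.trust_remote_code,
)
```

Run such a script with, for example, `--attn_implementation kernels-community/flash-attn2` to pick up the renamed kernel.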

trl/trainer/sft_trainer.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -72,9 +72,9 @@
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
     "flash_attention_3",
-    "kernels-community/flash-attn",
-    "kernels-community/vllm-flash-attn3",
+    "kernels-community/flash-attn2",
     "kernels-community/flash-attn3",
+    "kernels-community/vllm-flash-attn3",
 }
```
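`FLASH_ATTENTION_VARIANTS` is the set the trainer consults to recognize flash-attention-style backends, so renaming the Hub kernel here keeps detection consistent with the new name. A hedged sketch of that kind of membership check follows; the helper name and the `config._attn_implementation` attribute are assumptions, not code from this commit.

```python
# Sketch only: a membership check against the renamed variant set.
# `uses_flash_attention` is a hypothetical helper, not part of this commit.
FLASH_ATTENTION_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn2",
    "kernels-community/flash-attn3",
    "kernels-community/vllm-flash-attn3",
}


def uses_flash_attention(model) -> bool:
    """Return True if the model was loaded with any supported flash-attention variant."""
    return getattr(model.config, "_attn_implementation", None) in FLASH_ATTENTION_VARIANTS
```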
