3 files changed, +8 −8 lines changed

````diff
@@ -27,20 +27,20 @@ from transformers import AutoModelForCausalLM
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn"  # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
+    attn_implementation="kernels-community/flash-attn2"  # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
 )
 ```
 
 Or when running a TRL training script:
 
 ```bash
-python sft.py ... --attn_implementation kernels-community/flash-attn
+python sft.py ... --attn_implementation kernels-community/flash-attn2
 ```
 
 Or using the TRL CLI:
 
 ```bash
-trl sft ... --attn_implementation kernels-community/flash-attn
+trl sft ... --attn_implementation kernels-community/flash-attn2
 ```
 
 > [!TIP]
@@ -84,7 +84,7 @@ from trl import SFTConfig
 
 model = AutoModelForCausalLM.from_pretrained(
     "your-model-name",
-    attn_implementation="kernels-community/flash-attn"  # choose the desired FlashAttention variant
+    attn_implementation="kernels-community/flash-attn2"  # choose the desired FlashAttention variant
 )
 
 training_args = SFTConfig(
````
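
Taken together, the updated snippets compose into the following end-to-end sketch of the documented flow. It is an illustration, not code from this PR: `your-model-name` is a placeholder, `trl-lib/Capybara` is just an example dataset, and running it assumes the `kernels` package is installed so the Hub kernel can be downloaded.

```python
# Sketch: load a model with the renamed Kernels Hub attention backend and fine-tune it with TRL.
from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

model = AutoModelForCausalLM.from_pretrained(
    "your-model-name",  # placeholder: any causal LM checkpoint
    attn_implementation="kernels-community/flash-attn2",  # kernel fetched from the Hugging Face Hub
)

dataset = load_dataset("trl-lib/Capybara", split="train")  # example dataset, swap in your own

training_args = SFTConfig(output_dir="sft-output")
trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset)
trainer.train()
```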

```diff
@@ -43,8 +43,8 @@ class ModelConfig:
         be set to `True` for repositories you trust and in which you have read the code, as it will execute code
         present on the Hub on your local machine.
     attn_implementation (`str`, *optional*):
-        Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
-        you must install this manually by running `pip install flash-attn --no-build-isolation`.
+        Which attention implementation to use. More information in the [Kernels Hub Integrations
+        Guide](kernels_hub).
     use_peft (`bool`, *optional*, defaults to `False`):
         Whether to use PEFT for training.
     lora_r (`int`, *optional*, defaults to `16`):
```
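
For context on how these dataclass fields are consumed, here is a sketch (not from this PR) of a script parsing `ModelConfig` from the command line and forwarding `attn_implementation` to `from_pretrained`; the model name is a placeholder.

```python
# Sketch: parse ModelConfig from CLI flags and forward the attention backend choice.
from transformers import AutoModelForCausalLM, HfArgumentParser
from trl import ModelConfig

parser = HfArgumentParser(ModelConfig)
(model_args,) = parser.parse_args_into_dataclasses()

model = AutoModelForCausalLM.from_pretrained(
    "your-model-name",  # placeholder
    attn_implementation=model_args.attn_implementation,  # e.g. "kernels-community/flash-attn2"
)
```

Invoked, for example, as `python train.py --attn_implementation kernels-community/flash-attn2 --use_peft --lora_r 16`, using the flags documented above.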

```diff
@@ -72,9 +72,9 @@
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
     "flash_attention_3",
-    "kernels-community/flash-attn",
-    "kernels-community/vllm-flash-attn3",
+    "kernels-community/flash-attn2",
     "kernels-community/flash-attn3",
+    "kernels-community/vllm-flash-attn3",
 }
 
 
```
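
To make the effect of the rename concrete, here is a hypothetical membership check built on the set above (the helper and assertions are illustrative, not TRL code): callers can branch on whether the configured backend is any FlashAttention variant, and the old `kernels-community/flash-attn` name no longer matches.

```python
# Illustrative only: the set mirrors the new file state; the helper is hypothetical.
FLASH_ATTENTION_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn2",
    "kernels-community/flash-attn3",
    "kernels-community/vllm-flash-attn3",
}


def is_flash_attention(attn_implementation: str | None) -> bool:
    """Return True when the configured backend is a FlashAttention variant."""
    return attn_implementation in FLASH_ATTENTION_VARIANTS


assert is_flash_attention("kernels-community/flash-attn2")
assert not is_flash_attention("kernels-community/flash-attn")  # old name: no longer recognized
assert not is_flash_attention(None)
```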