# Speeding Up Training
This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation.
## vLLM for fast generation in online methods
[Online methods](index#online-methods) such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. For more details, see [vLLM Integration](vllm_integration).
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:

```bash
pip install trl[vllm]
```

<hfoptions id="vllm examples">
<hfoption id="Online DPO">
First, start a vLLM server by running:
```bash
trl vllm-serve --model <model_name>
```
Then, run the training script and pass `use_vllm=True` in the training arguments.
```python
from trl.experimental.online_dpo import OnlineDPOConfig

training_args = OnlineDPOConfig(..., use_vllm=True)
```

You can customize the server configuration by passing additional arguments.
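For example, you can shard the model across multiple GPUs with tensor parallelism (the flag names below are assumptions for illustration; run `trl vllm-serve --help` to see the exact options your version supports):

```bash
# Illustrative only: serve the model with tensor parallelism across 2 GPUs.
trl vllm-serve --model <model_name> --tensor-parallel-size 2
```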
</hfoption>
</hfoptions>
## Optimized attention implementations
TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage. You can use either locally installed backends (like Flash Attention 2) or pull pre-optimized kernels directly from the [Kernels Hub](kernels_hub).
<hfoptions id="attention examples">
<hfoption id="Flash Attention 2">
To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments.
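A minimal sketch of one way to do this, assuming the model is passed to the trainer as a string so that `model_init_kwargs` is forwarded to `from_pretrained` (shown with the SFT trainer; other trainer configs expose the same argument):

```python
from trl import SFTConfig

# Assumes flash-attn is installed; model_init_kwargs is passed through to
# the model's from_pretrained call when the model is given as a string.
training_args = SFTConfig(
    ...,
    model_init_kwargs={"attn_implementation": "flash_attention_2"},
)
```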
Other options include `kernels-community/vllm-flash-attn3` and `kernels-community/paged-attention`.
</hfoption>
</hfoptions>
Optimized attention works across all TRL trainers. For more details, see [Kernels Hub Integration](kernels_hub) and [Reducing Memory Usage](reducing_memory_usage#padding-free).
## PEFT for parameter-efficient training
[PEFT](https://huggingface.co/docs/peft/index) (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by only training a small number of adapter parameters instead of the full model.

```python
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)

training_args = SFTConfig(...)  # your usual training arguments

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    peft_config=peft_config,
    args=training_args,
)
```
For more details, see [PEFT Integration](peft_integration).
## Liger Kernel for memory optimization
Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.
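A minimal sketch of enabling it, assuming the SFT trainer and that the `liger-kernel` package is installed (`use_liger_kernel` comes from the underlying `transformers` training arguments, so the same flag is available in other TRL configs):

```python
from trl import SFTConfig

# Assumes liger-kernel is installed; the flag is inherited from
# transformers' TrainingArguments.
training_args = SFTConfig(..., use_liger_kernel=True)
```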
For more information, see [Liger Kernel Integration](liger_kernel_integration).
## Gradient checkpointing for memory savings
Gradient checkpointing trades compute for memory by not storing all intermediate activations during the forward pass, recomputing them during the backward pass instead.
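To enable it, set `gradient_checkpointing=True` in the training arguments (a minimal sketch with the SFT trainer; the flag is a standard `transformers` training argument shared by all TRL configs):

```python
from trl import SFTConfig

# Recompute activations during the backward pass instead of storing them all.
training_args = SFTConfig(..., gradient_checkpointing=True)
```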
Gradient checkpointing is available across all TRL trainers. For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).
## Mixed precision training
Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.
```python
from trl import SFTConfig
training_args = SFTConfig(..., bf16=True) # or fp16=True for older GPUs
```
Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. Mixed precision training is supported across all TRL trainers.