Commit c7d172b
docs: Expand speeding up training guide with acceleration methods (#4428)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
1 parent f1dfef0 commit c7d172b


docs/source/speeding_up_training.md

Lines changed: 136 additions & 5 deletions
@@ -1,12 +1,11 @@
# Speeding Up Training

This guide covers various methods to accelerate training in TRL. Each technique includes minimal examples with links to more comprehensive documentation.

## vLLM for fast generation in online methods

[Online methods](index#online-methods) such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. For more details, see [vLLM Integration](vllm_integration).

To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
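
```bash
pip install trl[vllm]
```
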
@@ -17,7 +16,13 @@ pip install trl[vllm]
<hfoptions id="vllm examples">
<hfoption id="Online DPO">

First, start a vLLM server by running:

```bash
trl vllm-serve --model <model_name>
```

Then, run the training script and pass `use_vllm=True` in the training arguments.

```python
from trl.experimental.online_dpo import OnlineDPOConfig
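# The rest of this snippet is elided by the diff hunk; as the text above states,
# vLLM generation is enabled by passing use_vllm=True in the training
# arguments. A minimal sketch of that step:
training_args = OnlineDPOConfig(..., use_vllm=True)
```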
@@ -95,3 +100,129 @@ You can customize the server configuration by passing additional arguments. For

</hfoption>
</hfoptions>

## Optimized attention implementations

TRL supports various optimized attention implementations that can significantly speed up training while reducing memory usage. You can use either locally installed backends (like Flash Attention 2) or pull pre-optimized kernels directly from the [Kernels Hub](kernels_hub).

<hfoptions id="attention examples">
<hfoption id="Flash Attention 2">

To enable Flash Attention 2, pass `attn_implementation="flash_attention_2"` in the model initialization arguments:

```python
from trl import SFTConfig

training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "flash_attention_2"})
```
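
Flash Attention 2 requires the `flash-attn` package to be installed in the training environment; assuming a CUDA build toolchain is available, the typical install is:

```bash
pip install flash-attn --no-build-isolation
```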

</hfoption>
<hfoption id="Kernels from Hub">

You can use pre-optimized attention kernels from the Hub without manual compilation:

```python
from trl import SFTConfig

training_args = SFTConfig(..., model_init_kwargs={"attn_implementation": "kernels-community/flash-attn2"})
```

Other options include `kernels-community/vllm-flash-attn3` and `kernels-community/paged-attention`.
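
Fetching kernels from the Hub relies on the `kernels` library being installed (an assumption here: the PyPI package name is `kernels`):

```bash
pip install kernels
```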

</hfoption>
</hfoptions>

Optimized attention works across all TRL trainers. For more details, see [Kernels Hub Integration](kernels_hub) and [Reducing Memory Usage](reducing_memory_usage#padding-free).

## PEFT for parameter-efficient training

[PEFT](https://huggingface.co/docs/peft/index) (Parameter-Efficient Fine-Tuning) methods like LoRA significantly reduce memory usage and training time by training only a small number of adapter parameters instead of the full model.

```python
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

# LoRA adapter configuration: rank-16 adapters on the attention projections
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)

training_args = SFTConfig(...)

# Passing peft_config makes the trainer wrap the model with LoRA adapters
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    peft_config=peft_config,
    args=training_args,
)
```
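
After the trainer is constructed, you can sanity-check how few parameters are actually trainable (a minimal sketch; it assumes the PEFT-wrapped model is exposed as `trainer.model`):

```python
# Prints trainable vs. total parameter counts and the trainable percentage
trainer.model.print_trainable_parameters()
```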

For more details, see [PEFT Integration](peft_integration).

## Liger Kernel for memory optimization

Liger Kernel is a collection of Triton kernels designed for LLM training that can increase throughput by 20% and reduce memory usage by 60%.
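
The kernels ship as a separate package, so install it before enabling the flag (assuming the PyPI package name `liger-kernel`):

```bash
pip install liger-kernel
```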

<hfoptions id="liger examples">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="DPO">

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_liger_kernel=True)
```

</hfoption>
<hfoption id="KTO">

```python
from trl import KTOConfig

training_args = KTOConfig(..., use_liger_kernel=True)
```

</hfoption>
</hfoptions>

For more information, see [Liger Kernel Integration](liger_kernel_integration).

## Gradient checkpointing for memory savings

Gradient checkpointing trades compute for memory: instead of storing all intermediate activations during the forward pass, it recomputes them during the backward pass.

```python
from trl import SFTConfig

training_args = SFTConfig(..., gradient_checkpointing=True)
```
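
Because the TRL configs inherit from the transformers `TrainingArguments`, you can also tune how checkpointing is performed (a sketch; `gradient_checkpointing_kwargs` is forwarded to `torch.utils.checkpoint`):

```python
from trl import SFTConfig

# The non-reentrant implementation avoids common warnings and restrictions
training_args = SFTConfig(
    ...,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```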

Gradient checkpointing is available across all TRL trainers. For more memory optimization techniques, see the [Transformers Performance Guide](https://huggingface.co/docs/transformers/perf_train_gpu_one#gradient-checkpointing).

## Mixed precision training

Mixed precision training using bf16 or fp16 can speed up training and reduce memory usage with minimal impact on model quality.

```python
from trl import SFTConfig

training_args = SFTConfig(..., bf16=True)  # or fp16=True for older GPUs
```

Use `bf16=True` for Ampere GPUs (A100, RTX 30xx) or newer, and `fp16=True` for older GPUs. Mixed precision training is supported across all TRL trainers.
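
If you are unsure whether a given GPU supports bf16, you can check at runtime with PyTorch's built-in capability query (a small sketch):

```python
import torch

from trl import SFTConfig

# Prefer bf16 on hardware that supports it; otherwise fall back to fp16
if torch.cuda.is_bf16_supported():
    training_args = SFTConfig(..., bf16=True)
else:
    training_args = SFTConfig(..., fp16=True)
```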
