FEAT: Add CLIs in TRL ! (huggingface#1419)
* CLI V1
* v1 CLI
* add rich enhancmeents
* revert unindented change
* some comments
* cleaner CLI
* fix
* fix
* remove print callback
* move to cli instead of trl_cli
* revert unneeded changes
* fix test
* Update trl/commands/sft.py
  Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
* remove redundant strings
* fix import issue
* fix other issues
* add packing
* add config parser
* some refactor
* cleaner
* add example config yaml file
* small refactor
* change a bit the logic
* fix issues here and there
* add CLI in docs
* move to examples/sft
* remove redundant licenses
* make it work on dpo
* set to None
* switch to accelerate and fix many things
* add docs
* more docs
* added tests
* doc clarification
* more docs
* fix CI for windows and python 3.8
* fix
* attempt to fix CI
* fix?
* test
* fix
* tweak?
* fix
* test
* another test
* fix
* test
* fix
* fix
* fix
* skip tests for windows
* test @lvwerra approach
* make dev
* revert unneeded changes
* fix sft dpo
* optimize a bit
* address final comments
* update docs
* final comment

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
1 parent 304e208 · commit a2aa0f0
Showing 24 changed files with 1,085 additions and 233 deletions.
@@ -0,0 +1,87 @@

# Command Line Interfaces (CLIs)

You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO) using the TRL CLIs.

Currently supported CLIs are:

- `trl sft`
- `trl dpo`

## Get started

Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found on the Hub with the "text-generation" filter. Also make sure to pick a relevant dataset for your task.

Also make sure to run:

```bash
accelerate config
```

and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.

We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.

```yaml
model_name_or_path: HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
```

Save that config in a `.yaml` file and get started right away! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.:

```bash
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```

This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.

## Supported Arguments

We support all arguments from `transformers.TrainingArguments`; for loading your model, we support all arguments from `~trl.ModelConfig`:

[[autodoc]] ModelConfig

You can pass any of these arguments either to the CLI or the YAML file.
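
For example, assuming `~trl.ModelConfig` exposes PEFT and quantization fields such as `use_peft`, `lora_r`, `lora_alpha`, and `load_in_4bit` (check the autodoc listing above for the exact names), a config file might combine them with the usual training arguments. A sketch:

```yaml
model_name_or_path: HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
learning_rate: 0.0001
# The fields below are assumed to come from `~trl.ModelConfig`;
# verify the exact names against the autodoc listing above.
use_peft: true
lora_r: 16
lora_alpha: 32
load_in_4bit: true
```

The same fields should also be accepted on the command line (e.g. something like `--use_peft --lora_r 16`), though how boolean flags are parsed depends on the underlying argument parser.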

### Supervised Fine-tuning (SFT)

Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:

```bash
trl sft --config config.yaml --output_dir your-output-dir
```

The SFT CLI is based on the `examples/scripts/sft.py` script.

### Direct Preference Optimization (DPO)

First, follow the basic instructions above and run `trl dpo --output_dir <output_dir> <*args>`. Make sure to process your DPO dataset in the TRL format as follows:

1- Make sure to pre-tokenize the dataset using chat templates:

```bash
python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset
```

You might need to adapt `examples/datasets/tokenize_ds.py` to use your chat template.

2- Format the dataset into the TRL format (you can adapt `examples/datasets/anthropic_hh.py`):

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
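
If you want to build such a dataset yourself instead of adapting the script, the sketch below illustrates the pairwise `prompt` / `chosen` / `rejected` column layout that DPO training conventionally expects. The dataset name and rows are purely illustrative; check `examples/datasets/anthropic_hh.py` for the exact format the CLI consumes.

```python
# Minimal sketch of a pairwise preference dataset (illustrative values only).
from datasets import Dataset

pairs = {
    "prompt": ["What is the capital of France?"],
    "chosen": ["The capital of France is Paris."],
    "rejected": ["I am not sure."],
}

dataset = Dataset.from_dict(pairs)
# Push it to the Hub so the CLI can load it by name, e.g. --dataset_name your-hf-org/your-dataset
dataset.push_to_hub("your-hf-org/your-dataset")
```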

Once your dataset has been pushed, run the DPO CLI as follows:

```bash
trl dpo --config config.yaml --output_dir your-output-dir
```

The DPO CLI is based on the `examples/scripts/dpo.py` script.

@@ -0,0 +1,20 @@

# This is an example configuration file of TRL CLI, you can use it for
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
# The YAML file supports environment variables by adding an `env` field
# as below

# env:
#   CUDA_VISIBLE_DEVICES: 0

model_name_or_path: HuggingFaceM4/tiny-random-LlamaForCausalLM
dataset_name: imdb
dataset_text_field: text
report_to: none
learning_rate: 1e-4
lr_scheduler_type: cosine