Change KTO tokenization to use DPO's #2187

Open · wants to merge 33 commits into base: main · changes shown from 18 commits

Commits (33)
- 9d5105c: add argument for dropout (kawine, Oct 1, 2024)
- 1996fe0: increase default lr (kawine, Oct 1, 2024)
- 000bd35: change default lr in examples (kawine, Oct 1, 2024)
- 0fbcec6: fix bug in calculation of KL batch size (kawine, Oct 1, 2024)
- 847790d: KL batch size should be args.per_device_train_batch_size (kawine, Oct 1, 2024)
- 4a8645a: Update kto_trainer.mdx with hparam recs (kawine, Oct 1, 2024)
- 8aab352: typo (kawine, Oct 1, 2024)
- 599d6f6: allow dropout to be disabled (kawine, Oct 1, 2024)
- 0b2fda2: Merge branch 'huggingface:main' into kto-hyperparam (kawine, Oct 1, 2024)
- 67d7884: Use DPO tokenization functions where possible (kawine, Oct 2, 2024)
- f7d77b6: fix bugs in use of dpotrainer tokenization (kawine, Oct 2, 2024)
- ea23270: Merge branch 'huggingface:main' into kto-tokenize (kawine, Oct 2, 2024)
- 0e23f55: add prefixes and text to batch (kawine, Oct 2, 2024)
- ecfa423: minor changes (kawine, Oct 5, 2024)
- 62dd49f: Merge branch 'kto-tokenize' of https://github.com/kawine/trl into kto… (kawine, Oct 5, 2024)
- 1157202: Merge branch 'huggingface:main' into kto-tokenize (kawine, Oct 5, 2024)
- f0324a6: minor changes (kawine, Oct 6, 2024)
- 61365db: remove unnecessarily cols in kl dataset (kawine, Oct 6, 2024)
- 9fabdca: Update trl/trainer/kto_trainer.py (kashif, Oct 6, 2024)
- b0ee0d8: formatting (kashif, Oct 6, 2024)
- ea88ad1: Merge branch 'main' into kto-tokenize (kashif, Oct 6, 2024)
- 2f83f58: revert from merge (kashif, Oct 6, 2024)
- 42c5e92: fix tests to work with new tokenization format (kawine, Oct 6, 2024)
- 782cc51: fix test (kashif, Oct 6, 2024)
- c83cf4a: add back maybe_unpair_preference_dataset (kashif, Oct 6, 2024)
- b201d63: Update examples/scripts/kto.py (kashif, Oct 6, 2024)
- 729d59e: Update examples/scripts/kto.py (kashif, Oct 6, 2024)
- 9ae7f9f: Merge branch 'main' into kto-tokenize (kawine, Oct 8, 2024)
- 7c3b970: remove twice processing of training data (kawine, Oct 8, 2024)
- a9e87b5: fix more bugs with merge (kawine, Oct 8, 2024)
- cc36a81: Merge branch 'main' into kto-tokenize (kawine, Oct 9, 2024)
- c7f9dda: move tokenization helper functions to utils; streamline KL calc for KTO (kawine, Oct 10, 2024)
- 697dae2: Merge branch 'main' into kto-tokenize (kawine, Oct 10, 2024)
docs/source/kto_trainer.mdx (12 changes: 9 additions & 3 deletions)

````diff
@@ -7,6 +7,7 @@ For a full example have a look at [`examples/scripts/kto.py`].
 
 Depending on how good your base model is, you may or may not need to do SFT before KTO.
 This is different from standard RLHF and DPO, which always require SFT.
+You can also train with imbalanced data (more chosen than rejected examples, or vice-versa), but you will need to adjust hyperparameters accordingly (see below).
 
 ## Expected dataset format
 
@@ -51,7 +52,8 @@ kto_dataset_dict = {
 ```
 
 where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
-A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
+A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
+In theory, the dataset must contain at least one desirable and one undesirable completion; however, some people have had success running KTO on _only_ desirable or undesirable data (in the latter case, it is best to use a conservative learning rate).
 
 
 ## Expected model format
@@ -61,13 +63,17 @@ The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
 
 For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
 
-The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
+The `beta` refers to the hyperparameter that controls how quickly the loss saturates, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
 
 The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
 By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
 
 <Tip>
-It is strongly recommended you use a learning rate between `5e-7` and `5e-6` with an effective batch size between `8` and `32`, for both LoRA and full finetuning. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, using smaller batch sizes and/or more training epochs will give you better results.
+Every choice of `beta` has a maximum learning rate it will tolerate before learning degenerates. For the default `beta = 0.1`, this learning rate is `1e-6` for most models. The lower beta is, the lower your learning rate should be. In general, we strongly recommend a learning rate between `5e-7` and `5e-6`. Even if you are working with a small dataset, we do not recommend using a learning rate outside this range; instead, use more epochs.
 </Tip>
 
+<Tip>
+Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
+</Tip>
+
 ```py
````
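The revised docs above give numeric targets: a weight ratio between 1:1 and 4:3, a learning rate of at most `1e-6` at the default `beta = 0.1`, a per-step batch size of at least 4, and an effective batch size of 16 to 128. As a rough illustration only, here is a minimal sketch of a `KTOConfig`/`KTOTrainer` setup that follows those numbers; the model and dataset names are placeholders and the keyword names assume the TRL version this PR targets.

```py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

# Placeholder model and dataset; any causal LM plus a KTO-format dataset
# (prompt / completion / label) works the same way.
model_name = "trl-lib/qwen1.5-1.8b-sft"
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")

# Weight rule worked example: with 10,000 desirable and 3,000 undesirable completions,
# keeping desirable_weight = 1.0 means undesirable_weight should be roughly 2.5 to 3.3,
# so that (1.0 * 10,000) : (w * 3,000) stays between 1:1 and 4:3.
training_args = KTOConfig(
    output_dir="kto-aligned-model",
    beta=0.1,                       # default; lower beta tolerates only lower learning rates
    learning_rate=1e-6,             # recommended ceiling for beta = 0.1
    per_device_train_batch_size=4,  # per-step batch of at least 4 keeps the KL estimate usable
    gradient_accumulation_steps=8,  # effective batch size of 32, inside the 16-128 range
    num_train_epochs=1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
)

trainer = KTOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```

On a single GPU this gives an effective batch size of 4 x 8 = 32; if you train on several devices, scale `gradient_accumulation_steps` down accordingly.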
examples/scripts/kto.py (10 changes: 5 additions & 5 deletions)

```diff
@@ -20,7 +20,7 @@
     --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
     --per_device_train_batch_size 16 \
     --num_train_epochs 1 \
-    --learning_rate 5e-7 \
+    --learning_rate 1e-6 \
     --lr_scheduler_type=cosine \
     --gradient_accumulation_steps 1 \
     --logging_steps 10 \
@@ -36,7 +36,7 @@
     --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
     --per_device_train_batch_size 8 \
     --num_train_epochs 1 \
-    --learning_rate 5e-7 \
+    --learning_rate 1e-6 \
     --lr_scheduler_type=cosine \
     --gradient_accumulation_steps 1 \
     --logging_steps 10 \
@@ -98,16 +98,16 @@ class ScriptArguments:
     dataset = load_dataset(script_args.dataset_name)
 
     # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
-    dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)
+    # dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)
 
     # Apply chat template
     def format_dataset(example):
         if isinstance(example["completion"], str):
             example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
             example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
         else:
-            example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
-            example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
+            example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
+            example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
         return example
 
     # Compute that only on the main process for faster data processing.
```
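The script change above comments out the unconditional `maybe_unpair_preference_dataset` call in this 18-commit view (a later commit in the PR, c83cf4a, adds it back) and makes both chat-template branches apply the template to `prompt` and `completion` directly. As a hedged sketch of what that preprocessing does, the toy example below runs the same two steps on a single paired record; the record contents and the tokenizer choice are placeholders, not part of the PR.

```py
from datasets import Dataset
from transformers import AutoTokenizer
from trl import maybe_unpair_preference_dataset

tokenizer = AutoTokenizer.from_pretrained("trl-lib/qwen1.5-1.8b-sft")

# One toy DPO-format (paired) record: a prompt with a chosen and a rejected reply.
paired = Dataset.from_dict(
    {
        "prompt": [[{"role": "user", "content": "What color is the sky?"}]],
        "chosen": [[{"role": "assistant", "content": "The sky is blue."}]],
        "rejected": [[{"role": "assistant", "content": "Green, probably."}]],
    }
)

# Step 1: unpair into KTO format, i.e. (prompt, completion, label) rows.
unpaired = maybe_unpair_preference_dataset(paired)

# Step 2: render the conversations to text with the chat template,
# mirroring format_dataset in the script.
def to_text(example):
    example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
    example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
    return example

unpaired = unpaired.map(to_text)
print(unpaired[0]["label"], unpaired[0]["completion"][:80])
```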
trl/trainer/kto_config.py (3 changes: 2 additions & 1 deletion)

```diff
@@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments):
         Number of processes to use for processing the dataset.
     """
 
-    learning_rate: float = 5e-7
+    learning_rate: float = 1e-6
    max_length: Optional[int] = None
    max_prompt_length: Optional[int] = None
    max_completion_length: Optional[int] = None
@@ -90,6 +90,7 @@ class KTOConfig(TrainingArguments):
    truncation_mode: str = "keep_end"
    generate_during_eval: bool = False
    is_encoder_decoder: Optional[bool] = None
+    disable_dropout: bool = True
    precompute_ref_log_probs: bool = False
    model_init_kwargs: Optional[Dict[str, Any]] = None
    ref_model_init_kwargs: Optional[Dict[str, Any]] = None
```
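The two config fields touched here are the default learning rate (now `1e-6`) and the new `disable_dropout` flag, which defaults to `True` and mirrors the equivalent option on the DPO side. A short sketch of overriding them, assuming the field names land as shown in this diff:

```py
from trl import KTOConfig

# Defaults after this PR: lr of 1e-6 and dropout disabled during training.
args = KTOConfig(output_dir="kto-out")
print(args.learning_rate, args.disable_dropout)  # 1e-06 True

# Keep dropout active and use a more conservative learning rate, e.g. when
# training on only-desirable or only-undesirable data as the docs suggest.
args = KTOConfig(
    output_dir="kto-out",
    learning_rate=5e-7,
    disable_dropout=False,
)
```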