Skip to content

Commit

Permalink
Conversational dataset support for DPOTrainer (huggingface#2131)
Browse files Browse the repository at this point in the history
* conversational dataset support for dpo

* support standard dataset for extract prompt

* test standard dataset for extract prompt

* fix maybe

* fix maybe apply prompt

* style

* overwrite default learning rate of DPO

* style

* rlaif script

* `writer_batch_size` in `train_test_split`

* initial dpo doc refactoring

* vision data section in doc

* lil format modif

* refine Vision datasets

* refine doc

* test new loss type format

* restrcture loss function

* table loss type

* simplify `unsloth`

* improve doc

* looged metrics up

* refine loss section

* Fix label_smoothing parameter in DPOConfig

* dataset for test

* update readme

* Update docs/source/dpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* try colorized code block

* refine doc style

* further refine doc

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* re add pali gemma test

* Add missing period

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
  • Loading branch information
3 people authored Oct 2, 2024
1 parent 5c21de3 commit 78249d9
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 223 deletions.
15 changes: 3 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,15 @@ trainer.train()
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example on how to use the `DPOTrainer`:

```python
from trl import DPOConfig, DPOTrainer, maybe_extract_prompt, maybe_apply_chat_template
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/Capybara-Preferences", split="train")
dataset = dataset.map(maybe_extract_prompt)
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
args=training_args,
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
)
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, tokenizer=tokenizer)
trainer.train()
```

Expand Down
34 changes: 34 additions & 0 deletions docs/source/dataset_formats.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": "
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
```

Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.

### Unpaired preference

An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
Expand Down Expand Up @@ -710,3 +712,35 @@ dataset = dataset.remove_columns(["completion", "label"])
>>> dataset[0]
{'prompt': 'The sky is'}
```

## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.

A conversational vision dataset differs from a standard conversational dataset in two key ways:

1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.

Example:

```python
# Textual dataset format:
"content": "What color is the sky?"

# Vision dataset format:
"content": [
{"type": "image"},
{"type": "text", "text": "What color is the sky in the image?"}
]
```

An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:

<iframe
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>

301 changes: 137 additions & 164 deletions docs/source/dpo_trainer.mdx

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs/source/online_dpo_trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,7 @@ train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

training_args = OnlineDPOConfig(output_dir="online-dpo-qwen2", logging_steps=10)
trainer = OnlineDPOTrainer(
model=model,
reward_model=reward_model,
args=training_args,
tokenizer=tokenizer,
train_dataset=train_dataset,
model=model, reward_model=reward_model, args=training_args, tokenizer=tokenizer, train_dataset=train_dataset
)
trainer.train()
```
Expand Down
73 changes: 73 additions & 0 deletions examples/datasets/rlaif-v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from datasets import features, load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
r"""
Arguments for the script.
Args:
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the dataset to the Hugging Face Hub.
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
Hugging Face repository ID to push the dataset to.
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of workers to use for dataset processing.
"""

push_to_hub: bool = False
repo_id: str = "trl-lib/rlaif-v"
dataset_num_proc: Optional[int] = None


def to_conversational(example):
"""
Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}]
and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}].
Images are wrapped into a list.
"""
prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}]
chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}]
rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}]
return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}


if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train")
dataset = dataset.map(
to_conversational,
num_proc=script_args.dataset_num_proc,
remove_columns=dataset.column_names,
writer_batch_size=128,
)

# Cast the images to Sequence[Image] to avoid bytes format
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True))
dataset = dataset.cast(f)

dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128)

if script_args.push_to_hub:
dataset.push_to_hub(script_args.repo_id)
9 changes: 0 additions & 9 deletions examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
"""

import torch
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

Expand All @@ -60,8 +59,6 @@
get_kbit_device_map,
get_peft_config,
get_quantization_config,
maybe_apply_chat_template,
maybe_extract_prompt,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

Expand Down Expand Up @@ -115,12 +112,6 @@
################
dataset = load_dataset(script_args.dataset_name)

with PartialState().local_main_process_first():
dataset = dataset.map(maybe_extract_prompt, num_proc=training_args.dataset_num_proc)
dataset = dataset.map(
maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
)

##########
# Training
################
Expand Down
2 changes: 1 addition & 1 deletion tests/slow/test_dpo_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@require_torch_accelerator
class DPOTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.dataset = load_dataset("trl-internal-testing/mlabonne-chatml-dpo-pairs-copy", split="train[:10%]")
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference")
self.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
Expand Down
60 changes: 49 additions & 11 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def test_maybe_unpair_preference_dataset_dict_already_paired(self):


class ExtractPromptTester(unittest.TestCase):
example_implicit_prompt = {
example_implicit_prompt_conversational = {
"chosen": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
Expand All @@ -279,7 +279,7 @@ class ExtractPromptTester(unittest.TestCase):
],
}

example_explicit_prompt = {
example_explicit_prompt_conversational = {
"prompt": [
{"role": "user", "content": "What color is the sky?"},
],
Expand All @@ -291,30 +291,68 @@ class ExtractPromptTester(unittest.TestCase):
],
}

def test_extract_prompt(self):
example_implicit_prompt_standard = {
"chosen": "The sky is blue.",
"rejected": "The sky is green.",
}

example_explicit_prompt_standard = {
"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green.",
}

def test_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_conversational_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt_conversational,
"The prompt should remain unchanged.",
)

def test_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset
example_extracted_prompt = extract_prompt(self.example_implicit_prompt)
example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt(self):
def test_maybe_extract_prompt_standard(self):
# Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt)
example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt is not correctly extracted from the dataset.",
)

def test_maybe_extract_prompt_already_explicit(self):
def test_maybe_extract_prompt_standard_already_explicit(self):
# Test that the prompt remains unchanged with maybe_extract_prompt
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt)
example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
self.assertEqual(
example_extracted_prompt,
self.example_explicit_prompt,
self.example_explicit_prompt_standard,
"The prompt should remain unchanged.",
)

Expand Down
11 changes: 1 addition & 10 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,7 @@ class DPOVisionTrainerTester(unittest.TestCase):
@parameterized.expand(
[
["trl-internal-testing/tiny-random-idefics2"],
# ["trl-internal-testing/tiny-random-paligemma"], # temporarily disabled due to flaky tests
["trl-internal-testing/tiny-random-paligemma"],
["trl-internal-testing/tiny-random-llava-1.5"],
]
)
Expand Down Expand Up @@ -1094,15 +1094,6 @@ def test_vdpo_trainer(self, model_id):
ref_model = AutoModelForVision2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Apply chat template to the dataset
def apply_chat_template(example):
example["prompt"] = processor.apply_chat_template(example["prompt"])
example["chosen"] = processor.apply_chat_template(example["chosen"])
example["rejected"] = processor.apply_chat_template(example["rejected"])
return example

dataset = dataset.map(apply_chat_template)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
Expand Down
22 changes: 13 additions & 9 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, TypeVar
from typing import Any, Dict, List, Optional, Sequence, TypeVar

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
Expand Down Expand Up @@ -280,15 +280,17 @@ def maybe_unpair_preference_dataset(dataset: DatasetType, num_proc: Optional[int
return dataset


def extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
def extract_prompt(example: Dict[str, Sequence]) -> Dict[str, Sequence]:
r"""
Extracts the shared prompt from a preference data example, where the prompt is implicit within both
the chosen and rejected completions.
For more details, see [`maybe_extract_prompt`].
"""
for idx in range(min(len(example["chosen"]), len(example["rejected"]))):
if example["chosen"][idx]["content"] != example["rejected"][idx]["content"]:
if example["chosen"][idx] != example["rejected"][idx]:
if example["chosen"][idx - 1] == " ": # remove space before the prompt
idx -= 1
break
return {
"prompt": example["chosen"][:idx],
Expand All @@ -303,15 +305,14 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
the chosen and rejected completions.
If the example already contains a `"prompt"` key, the function returns the example as is. Else, the function
identifies the longest common sequence (prefix) of conversation turns between the "chosen" and "rejected"
completions and extracts this as the prompt. It then removes this prompt from the respective "chosen" and
"rejected" completions.
Args:
example (`Dict[str, List]`):
A dictionary representing a single data entry in the preference dataset. It must contain the keys
`"chosen"` and `"rejected"`, where each value is a list.
`"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`).
Returns:
`Dict[str, List]`: A dictionary containing:
Expand Down Expand Up @@ -379,7 +380,10 @@ def maybe_extract_prompt(example: Dict[str, List]) -> Dict[str, List]:
# "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
# "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]}
# That's why we check if the prompt is also conversational before deciding not to extract it.
if "prompt" in example and is_conversational({"prompt": example["prompt"]}):
return example
else:
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
if "prompt" in example:
# Both conversational or both non-conversational
chosen_conv = is_conversational({"chosen": example["chosen"]})
prompt_conv = is_conversational({"prompt": example["prompt"]})
if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv):
return example
return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]})
Loading

0 comments on commit 78249d9

Please sign in to comment.