Visual DPO (huggingface#1647)
* Remove extra whitespaces

* idefics

* vdpo

* sft idefics

* pad with test

* use prompt instead of tokenizer

* rm name main

* support vlm in tokenize row

* temp fix for regex in lora_target_module

* format

* vdpo

* tmp float16 hard code

* concatenated_forward support for vision

* style and new command line

* all-linear

* format

* delete old examples

* get image

* upcast

* new test

* modified test

* new strat for tokenizer

* rm token transfer

* integrate vision in dpo example

* format

* add FDivergenceType back

* precommit

* pillow test dep

* optional prompt

* `evaluation_strategy` to `eval_strategy`

* revert vsft change (oos)

* update test

* test

* comment and support more in process

* update process

* update doc for vdpo

* caution about limited support

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* revert DPO example changes

* cleaner way to check if a model is vision

* comment

* update vdpo example

* rename

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
3 people authored Jun 26, 2024
1 parent c8c01cc commit b68ff96
Showing 8 changed files with 534 additions and 44 deletions.
29 changes: 27 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -70,8 +70,33 @@ dpo_dataset_dict = {

where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses, and `rejected` contains the corresponding negative (rejected) responses. As can be seen, a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.

`DPOTrainer` can also be used to fine-tune visual language models (VLMs). In this case, the dataset must additionally contain the key `images`, and the trainer's `tokenizer` is the VLM's `processor`. Note that VLM support is currently limited to Idefics2 and does not extend to other VLMs.

For example, for Idefics2, the processor expects the dataset to have the following format:

```py
from PIL import Image

dpo_dataset_dict = {
    'images': [
        [Image.open('beach.jpg')],
        [Image.open('street.jpg')],
    ],
    'prompt': [
        'The image <image> shows',
        '<image> The image depicts',
    ],
    'chosen': [
        'a sunny beach with palm trees.',
        'a busy street with several cars and buildings.',
    ],
    'rejected': [
        'a snowy mountain with skiers.',
        'a calm countryside with green fields.',
    ],
}
```
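
If you build such a dataset with the `datasets` library, the column holding lists of images may not have its features inferred correctly. A minimal sketch of declaring the features explicitly, mirroring what the project's test suite does (the image and text values here are illustrative placeholders):

```py
from datasets import Dataset, features
from PIL import Image

dpo_dataset_dict = {
    "images": [[Image.new("RGB", (64, 64), color="blue")]],
    "prompt": ["The image <image> shows"],
    "chosen": ["a solid blue square."],
    "rejected": ["a red circle."],
}

# Declare the image column explicitly so it is stored as a sequence of images.
f = features.Features(
    {
        "images": features.Sequence(features.Image(decode=True)),
        "prompt": features.Value("string"),
        "chosen": features.Value("string"),
        "rejected": features.Value("string"),
    }
)
dataset = Dataset.from_dict(dpo_dataset_dict, features=f)
```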

## Expected model format
-The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
+The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.

## Using the `DPOTrainer`

@@ -86,7 +111,7 @@ dpo_trainer = DPOTrainer(
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
-    tokenizer=tokenizer,
+    tokenizer=tokenizer,  # for visual language models, use tokenizer=processor instead
)
```
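
For a VLM, the same call is made with the model loaded as `AutoModelForVision2Seq` and the processor passed in place of the tokenizer, as done in `examples/scripts/dpo_visual.py`. A condensed sketch, assuming `training_args` and `train_dataset` are defined as above:

```py
from transformers import AutoModelForVision2Seq, AutoProcessor
from trl import DPOTrainer

model_id = "HuggingFaceM4/idefics2-8b"
model = AutoModelForVision2Seq.from_pretrained(model_id)
model_ref = AutoModelForVision2Seq.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id, do_image_splitting=False)

dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=processor,  # the processor wraps both the tokenizer and the image processor
)
```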
After this one can then call:
177 changes: 177 additions & 0 deletions examples/scripts/dpo_visual.py
@@ -0,0 +1,177 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
accelerate launch examples/scripts/vdpo.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path HuggingFaceM4/idefics2-8b \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataset_num_proc 32 \
--output_dir dpo_idefics_rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules=all-linear
"""

import logging
import os
from contextlib import nullcontext

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser
from accelerate import PartialState

if TRL_USE_RICH:
    init_zero_verbose()
    FORMAT = "%(message)s"

    from rich.console import Console
    from rich.logging import RichHandler

import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

from trl import (
    DPOConfig,
    DPOTrainer,
    ModelConfig,
    RichProgressCallback,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if TRL_USE_RICH:
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
    parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    # Force use our print callback
    if TRL_USE_RICH:
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)

    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
    else:
        model_ref = None
    processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
    tokenizer = processor.tokenizer
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    ################
    # Optional rich context managers
    ################
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
    )

    ################
    # Dataset
    ################
    ds = load_dataset(args.dataset_name)
    if args.sanity_check:
        for key in ds:
            ds[key] = ds[key].select(range(50))

    def process(row):
        # The prompt can be either a string or a list. In some datasets, the prompt is just a common string
        # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
        # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
        # and in such cases, it is properly formatted as a list with keys "role" and "content".
        # Example 1:
        # row = {"prompt": "What does detox mean?",
        #        "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
        #        "rejected": [{"content": "What does detox mean?", "role": "user"}, {"content": "I don't know.", "role": "assistant"}]}
        # Example 2:
        # row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
        #        "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
        #        "rejected": [{"content": "I don't know.", "role": "assistant"}]}
        if "prompt" in row and isinstance(row["prompt"], list):
            row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)

        row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)

        if "images" in row:
            for idx, img in enumerate(row["images"]):  # Resize each image so the largest side is 640 pixels
                ratio = min(1.0, 640 / max(img.size))
                new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
                row["images"][idx] = img.resize(new_size)

        return row

    with PartialState().local_main_process_first():
        ds = ds.map(process, num_proc=training_args.dataset_num_proc)
        train_dataset = ds[args.dataset_train_split]
        eval_dataset = ds[args.dataset_test_split]

    ################
    # Training
    ################
    with init_context:
        trainer = DPOTrainer(
            model,
            model_ref,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=processor,
            peft_config=peft_config,
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
        )

    trainer.train()

    with save_context:
        trainer.save_model(training_args.output_dir)
        if training_args.push_to_hub:
            trainer.push_to_hub()
11 changes: 10 additions & 1 deletion setup.py
@@ -69,7 +69,16 @@
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "scikit-learn"],
"test": [
"parameterized",
"pytest",
"pytest-xdist",
"accelerate",
"pytest-cov",
"pytest-xdist",
"scikit-learn",
"Pillow",
],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],
120 changes: 118 additions & 2 deletions tests/test_dpo_trainer.py
@@ -15,10 +15,17 @@
import unittest

import torch
-from datasets import Dataset
+from datasets import Dataset, features
from parameterized import parameterized
from PIL import Image
from pytest import mark
-from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSeq2SeqLM,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    AutoTokenizer,
+)

from trl import DPOConfig, DPOTrainer, FDivergenceType

@@ -40,6 +47,12 @@ def setUpClass(cls):
        cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)

        # get idefics2 model
        model_id = "trl-internal-testing/tiny-random-idefics2"
        cls.idefics2_model = AutoModelForVision2Seq.from_pretrained(model_id)
        cls.idefics2_ref_model = AutoModelForVision2Seq.from_pretrained(model_id)
        cls.idefics2_processor = AutoProcessor.from_pretrained(model_id)

    def _init_dummy_dataset(self):
        # fmt: off
        dummy_dataset_dict = {
@@ -80,6 +93,57 @@ def _init_dummy_dataset(self):
        # fmt: on
        return Dataset.from_dict(dummy_dataset_dict)

    def _init_dummy_image_dataset(self):
        # fmt: off
        dummy_dataset_dict = {
            "images": [
                [Image.new("RGB", (100, 50), color="black")],
                # None,
                # [Image.new("RGB", (100, 100), color="blue"), Image.new("RGB", (150, 50), color="red")],
                [Image.new("RGB", (200, 100), color="green")],
                # [Image.new("RGB", (150, 150), color="yellow"), Image.new("RGB", (50, 150), color="purple")],
                [Image.new("RGB", (80, 120), color="gray")],
                [Image.new("RGB", (120, 80), color="pink")],
            ],
            "prompt": [
                "<image> Hello",
                # "How are you?",
                # "<image><image> Let's chat",
                "<image> Good morning",
                # "<image><image> What's up?",
                "Can you see this? <image>",
                "Here is something interesting: <image>",
            ],
            "chosen": [
                "Hi nice to meet you!",
                # "I'm doing well, thank you!",
                # "Sure, let's talk!",
                "Good morning to you too!",
                # "Not much, just working.",
                "Yes, I can see it clearly.",
                "That's quite interesting indeed.",
            ],
            "rejected": [
                "Leave me alone!",
                # "I'm not interested.",
                # "I don't want to chat.",
                "I'm still sleepy.",
                # "Busy right now, talk later.",
                "No, I can't see it.",
                "I'm not sure what that is.",
            ],
        }
        # fmt: on
        f = features.Features(
            {
                "images": features.Sequence(features.Image(decode=True)),  # datasets handles sequences of images badly
                "prompt": features.Value("string"),
                "chosen": features.Value("string"),
                "rejected": features.Value("string"),
            }
        )
        return Dataset.from_dict(dummy_dataset_dict, features=f)

    @parameterized.expand(
        [
            ["gpt2", "sigmoid", True],
Expand Down Expand Up @@ -152,6 +216,54 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
                if param.sum() != 0:
                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

    @parameterized.expand(
        [
            ["sigmoid", True],
        ]
    )
    def test_vdpo_trainer(self, loss_type, pre_compute):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                remove_unused_columns=False,
                gradient_accumulation_steps=1,
                learning_rate=9e-1,
                eval_strategy="steps",
                beta=0.1,
                loss_type=loss_type,
                precompute_ref_log_probs=pre_compute,
            )

            dummy_dataset = self._init_dummy_image_dataset()

            model = self.idefics2_model
            ref_model = self.idefics2_ref_model
            processor = self.idefics2_processor

            trainer = DPOTrainer(
                model=model,
                ref_model=ref_model,
                args=training_args,
                tokenizer=processor,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            assert trainer.state.log_history[-1]["train_loss"] is not None

            # check the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                # check the params have changed - ignore 0 biases
                if param.sum() != 0:
                    assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)

    def test_dpo_trainer_without_providing_ref_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
@@ -811,3 +923,7 @@ def test_dpo_loss_js_div_f(self):
            policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps
        )
        assert torch.isfinite(losses).cpu().numpy().all()


if __name__ == "__main__":
    unittest.main()