Visual DPO #1647

Merged · 44 commits · Jun 26, 2024

Commits
8768fe6
Remove extra whitespaces
qgallouedec May 17, 2024
5d43f2b
idefics
qgallouedec May 17, 2024
f5a3237
vdpo
qgallouedec May 27, 2024
682c034
sft idefics
qgallouedec May 27, 2024
bf01bf3
pad with test
qgallouedec May 30, 2024
aed1aeb
use prompt instead of tokenizer
qgallouedec May 30, 2024
e814f88
rm name main
qgallouedec May 30, 2024
fd5d71b
support vlm in tokenize row
qgallouedec May 30, 2024
e1b8755
temp fix for regex in lora_target_module
qgallouedec May 30, 2024
8075419
format
qgallouedec May 31, 2024
1b815c2
vdpo
qgallouedec May 31, 2024
6d6a194
tmp float16 hard code
qgallouedec Jun 3, 2024
1935d3d
concatenated_forward support for vision
qgallouedec Jun 3, 2024
bdc2b95
style and new command line
qgallouedec Jun 17, 2024
24b08f5
all-linear
qgallouedec Jun 17, 2024
c5ff8d7
format
qgallouedec Jun 18, 2024
a7d1732
delete old examples
qgallouedec Jun 18, 2024
2303c40
get image
qgallouedec Jun 18, 2024
b606190
upcast
qgallouedec Jun 18, 2024
4f78ee5
new test
qgallouedec Jun 18, 2024
c4433c0
modified test
qgallouedec Jun 18, 2024
7a8a94f
new strat for tokenizer
qgallouedec Jun 18, 2024
a9a4607
Merge branch 'main' into fix-vsft-example
qgallouedec Jun 25, 2024
9955710
rm token transfer
qgallouedec Jun 25, 2024
f6ee370
integrate vision in dpo example
qgallouedec Jun 25, 2024
56fb036
format
qgallouedec Jun 25, 2024
c3249e5
add FDivergenceType back
qgallouedec Jun 25, 2024
f69bb1c
precommit
qgallouedec Jun 25, 2024
6d859cf
pillow test dep
qgallouedec Jun 25, 2024
48db3e1
optional prompt
qgallouedec Jun 25, 2024
dea765b
`evaluation_strategy` to `eval_strategy`
qgallouedec Jun 25, 2024
d6dc3ba
revert vsft change (oos)
qgallouedec Jun 25, 2024
3a1f5b8
update test
qgallouedec Jun 25, 2024
5545825
test
qgallouedec Jun 25, 2024
5197d6d
comment and support more in process
qgallouedec Jun 26, 2024
45fda7e
update process
qgallouedec Jun 26, 2024
5a1dfa7
update doc for vdpo
qgallouedec Jun 26, 2024
2c10ca8
caution about limited support
qgallouedec Jun 26, 2024
2e47633
Update docs/source/dpo_trainer.mdx
qgallouedec Jun 26, 2024
f960a2a
revert DPO example changes
qgallouedec Jun 26, 2024
e4c7436
cleaner way to check if a model is vision
qgallouedec Jun 26, 2024
bfb35d3
comment
qgallouedec Jun 26, 2024
7b22153
update vdpo example
qgallouedec Jun 26, 2024
5155194
rename
qgallouedec Jun 26, 2024
29 changes: 27 additions & 2 deletions docs/source/dpo_trainer.mdx
@@ -70,8 +70,33 @@ dpo_dataset_dict = {

where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen, a prompt can have multiple responses, and this is reflected in the entries being repeated in the dictionary's value arrays.
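As an illustration (the strings below are placeholders), a text-only preference dictionary with a repeated prompt could look like:

```py
dpo_dataset_dict = {
    'prompt': [
        'hello',
        'hello',
        'how are you',
    ],
    'chosen': [
        'hi, nice to meet you',
        'hi you',
        'I am fine',
    ],
    'rejected': [
        'leave me alone',
        'go away',
        'I am not fine',
    ],
}
```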

`DPOTrainer` can be used to fine-tune visual language models (VLMs). In this case, the dataset must also contain the key `images`. For example, for Idefics2, the processor expects the dataset to have the following format:

Note: Currently, VLM support is exclusive to Idefics2 and does not extend to other VLMs.

```py
from PIL import Image

dpo_dataset_dict = {
'images': [
[Image.open('beach.jpg')],
[Image.open('street.jpg')],
],
'prompt': [
'The image <image> shows',
'<image> The image depicts',
],
'chosen': [
'a sunny beach with palm trees.',
'a busy street with several cars and buildings.',
],
'rejected': [
'a snowy mountain with skiers.',
'a calm countryside with green fields.',
],
}
```
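Large images make preprocessing and collation noticeably slower, so the example scripts in this PR shrink every image so that its longest side is at most 640 pixels before mapping the dataset. A minimal sketch of that resizing step (the 640-pixel cap simply mirrors the value used in the scripts):

```py
from PIL import Image

def resize_to_max_side(img: Image.Image, max_side: int = 640) -> Image.Image:
    # Downscale so the longest side is at most `max_side`, preserving the aspect ratio.
    ratio = min(1.0, max_side / max(img.size))
    return img.resize((int(img.size[0] * ratio), int(img.size[1] * ratio)))
```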

## Expected model format
The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
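For a VLM such as Idefics2, this means loading the policy (and, when not using PEFT, the reference model) with `AutoModelForVision2Seq` and pairing it with the matching `AutoProcessor`. A minimal sketch, following the example script in this PR:

```py
from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "HuggingFaceM4/idefics2-8b"
model = AutoModelForVision2Seq.from_pretrained(model_id)
model_ref = AutoModelForVision2Seq.from_pretrained(model_id)  # can be None when training with PEFT
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer
```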

## Using the `DPOTrainer`

@@ -86,7 +111,7 @@ dpo_trainer = DPOTrainer(
model_ref,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
tokenizer=tokenizer, # for visual language models, use tokenizer=processor instead
)
```
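For a VLM, the same construction takes the processor in place of the tokenizer (a sketch reusing the `model`, `model_ref` and `processor` loaded above):

```py
dpo_trainer = DPOTrainer(
    model,
    model_ref,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=processor,  # the processor bundles the tokenizer and the image processor
)
```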
After this, one can then call `dpo_trainer.train()`.
76 changes: 63 additions & 13 deletions examples/scripts/dpo.py
@@ -48,6 +48,19 @@
--use_peft \
--lora_r=16 \
--lora_alpha=16

# vision with peft:
accelerate launch examples/scripts/dpo.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path HuggingFaceM4/idefics2-8b \
--output_dir dpo_idefics_rlaif-v \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 1e-5 \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules=all-linear
"""

import logging
@@ -58,6 +71,7 @@
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser
from accelerate import PartialState

if TRL_USE_RICH:
init_zero_verbose()
@@ -68,7 +82,7 @@

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForVision2Seq, AutoProcessor

from trl import (
DPOConfig,
@@ -112,13 +126,25 @@
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
is_vision_model = model_config.model_name_or_path in ["HuggingFaceM4/idefics2-8b"]
if is_vision_model:
model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
if is_vision_model:
model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = None
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
if is_vision_model:
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=True)
tokenizer = processor.tokenizer
else:
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
@@ -148,16 +174,40 @@
ds[key] = ds[key].select(range(50))

def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
row["chosen"] = tokenizer.apply_chat_template([row["chosen"][-1]], tokenize=False)
row["rejected"] = tokenizer.apply_chat_template([row["rejected"][-1]], tokenize=False)
# The prompt can be either a string or a list. In some datasets, the prompt is just a common string
# for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
# separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
# and in such cases, it is properly formatted as a list with keys "role" and "content".
# Example 1:
# row = {"prompt": "What does detox mean?",
# "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "assistant"}]}
# Example 2:
# row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
# "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "I don't know.", "role": "assistant"}]}
if is_vision_model:
apply_chat_template = processor.apply_chat_template
else:
apply_chat_template = tokenizer.apply_chat_template

if "prompt" in row and isinstance(row["prompt"], list):
row["prompt"] = apply_chat_template(row["prompt"], tokenize=False)

row["chosen"] = apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = apply_chat_template(row["rejected"], tokenize=False)

if "images" in row:
for idx, img in enumerate(row["images"]): # Resize each image so the largest side is 640 pixels
ratio = min(1.0, 640 / max(img.size))
new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
row["images"][idx] = img.resize(new_size)
row["images"] = row["images"]

return row

ds = ds.map(
process,
num_proc=multiprocessing.cpu_count(),
load_from_cache_file=False,
)
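# Run the map on the local main process first so that the other processes reuse the cached result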
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=multiprocessing.cpu_count())
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

Expand All @@ -171,7 +221,7 @@ def process(row):
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
tokenizer=processor if is_vision_model else tokenizer,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)
186 changes: 186 additions & 0 deletions examples/scripts/vdpo.py
@@ -0,0 +1,186 @@
# flake8: noqa
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
accelerate launch examples/scripts/vdpo.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path HuggingFaceM4/idefics2-8b \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--learning_rate 1e-5 \
--logging_steps 5 \
--output_dir dpo_idefics_rlaif-v \
--push_to_hub --hub_model_id HuggingFaceH4/idefics2-8b-dpo-rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--logging_first_step \
--no_remove_unused_columns \
--dataset_num_proc 50 \
--dataloader_num_workers 16 \
--use_peft \
--lora_target_modules=all-linear
"""

import logging
import multiprocessing
import os
from contextlib import nullcontext

TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)

from trl.commands.cli_utils import DPOScriptArguments, init_zero_verbose, TrlParser
from accelerate import PartialState

if TRL_USE_RICH:
init_zero_verbose()
FORMAT = "%(message)s"

from rich.console import Console
from rich.logging import RichHandler

import torch
from datasets import load_dataset
from transformers import AutoModelForVision2Seq, AutoProcessor

from trl import (
DPOConfig,
DPOTrainer,
ModelConfig,
RichProgressCallback,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)


if TRL_USE_RICH:
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)


if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
args, training_args, model_config = parser.parse_args_and_config()

# Force use our print callback
if TRL_USE_RICH:
training_args.disable_tqdm = True
console = Console()

################
# Model & Tokenizer
################
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)

model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
else:
model_ref = None
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
tokenizer = processor.tokenizer
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]

################
# Optional rich context managers
###############
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
save_context = (
nullcontext()
if not TRL_USE_RICH
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
)

################
# Dataset
################
ds = load_dataset(args.dataset_name)
if args.sanity_check:
for key in ds:
ds[key] = ds[key].select(range(50))

def process(row):
# The prompt can be either a string or a list. In some datasets, the prompt is just a common string
# for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
# separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
# and in such cases, it is properly formatted as a list with keys "role" and "content".
# Example 1:
# row = {"prompt": "What does detox mean?",
# "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]}
# Example 2:
# row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
# "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "I don't know.", "role": "user"}]}
if "prompt" in row and isinstance(row["prompt"], list):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)

row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)

if "images" in row:
for idx, img in enumerate(row["images"]): # Resize each image so the largest side is 640 pixels
ratio = min(1.0, 640 / max(img.size))
new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio))
row["images"][idx] = img.resize(new_size)
row["images"] = row["images"]

return row

with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=multiprocessing.cpu_count())
train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

################
# Training
################
with init_context:
trainer = DPOTrainer(
model,
model_ref,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=processor,
peft_config=get_peft_config(model_config),
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
)

trainer.train()

with save_context:
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub()
11 changes: 10 additions & 1 deletion setup.py
@@ -69,7 +69,16 @@
"tyro>=0.5.11",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "pytest-cov", "pytest-xdist", "scikit-learn"],
"test": [
"parameterized",
"pytest",
"pytest-xdist",
"accelerate",
"pytest-cov",
"pytest-xdist",
"scikit-learn",
"Pillow",
],
"peft": ["peft>=0.4.0"],
"diffusers": ["diffusers>=0.18.0"],
"deepspeed": ["deepspeed>=0.9.5"],