Padding free dpo #2437

Open · wants to merge 57 commits into main from dame-cell's padding_free_dpo branch
Changes from 51 commits
Commits (57)
ca99954
added eos token for ppotrainer
dame-cell Nov 30, 2024
fe1d5f6
remove the unnecessary stuff
dame-cell Nov 30, 2024
b15c635
Update ppo_config.py
dame-cell Nov 30, 2024
1bcb3a4
remove redundant EOS token fallback
dame-cell Dec 1, 2024
2ef2b24
remove redundant EOS token fallback
dame-cell Dec 1, 2024
42a0f73
remove some unnecessary tests stuff
dame-cell Dec 1, 2024
6130a91
added tests and update concatenated_inputs
dame-cell Dec 4, 2024
6732ed2
return only list and also a lot to do
dame-cell Dec 4, 2024
ce67292
padding free not tested but getting closer
dame-cell Dec 10, 2024
91e40aa
rebase and also reevaluate my approach
dame-cell Dec 10, 2024
8a34cb5
merge main
dame-cell Dec 10, 2024
1d38632
fix identation
dame-cell Dec 10, 2024
814d69e
better tests
dame-cell Dec 10, 2024
7dae607
concatenated_forward now supports padding_free
dame-cell Dec 10, 2024
1a59d74
collator now does not return attention masks
dame-cell Dec 10, 2024
562c52e
postion ids and no attention mask works
dame-cell Dec 11, 2024
d194054
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
3855851
update concatenated forward to support padding_free
dame-cell Dec 11, 2024
d9adbfb
Merge branch 'main' into padding_free_dpo
dame-cell Dec 11, 2024
1145006
grad accumalation tests
dame-cell Dec 13, 2024
24f73a4
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 13, 2024
f6bd9e1
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
bbd99cf
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
ba4969d
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
187b1e5
Resolved merge conflict in ppo_trainer.py
dame-cell Dec 13, 2024
1d9ce3e
fix identation
dame-cell Dec 13, 2024
8a974cc
comments update
dame-cell Dec 13, 2024
7f0298b
fix some small issue
dame-cell Dec 13, 2024
2900275
fix some small issue
dame-cell Dec 13, 2024
58e779a
fix some small issue
dame-cell Dec 13, 2024
6a1e251
fix some small issue
dame-cell Dec 13, 2024
f92e056
update concatenate_forward to support padding_fre
dame-cell Dec 13, 2024
457e3a1
fix some small issue
dame-cell Dec 13, 2024
f1789f4
fix some small issue
dame-cell Dec 13, 2024
0103728
So we need to make sure to correctlty handle the list
dame-cell Dec 13, 2024
5837faa
by correclty updatuing concatenated_forward it works now
dame-cell Dec 13, 2024
0321c1d
refactoring concatenated_forward and batched same length seq for padd…
dame-cell Dec 14, 2024
a6e2163
update
dame-cell Dec 14, 2024
1328fc3
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
51a2cc6
Merge branch 'main' into padding_free_dpo
dame-cell Dec 17, 2024
986ed71
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
570b79a
padding_free in concatenated_forward and update_test
dame-cell Dec 17, 2024
9dd9564
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 17, 2024
b781876
Merge branch 'main' into padding_free_dpo
dame-cell Dec 18, 2024
525ecb2
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
c8ce9c8
Merge branch 'main' into padding_free_dpo
dame-cell Dec 19, 2024
b7fad73
Reverted PPO trainer to original version and updated DPO files
dame-cell Dec 19, 2024
55cd219
Merge branch 'padding_free_dpo' of https://github.com/dame-cell/trl i…
dame-cell Dec 19, 2024
5e8df69
Updated DPO files
dame-cell Dec 19, 2024
ba1ded1
Merge branch 'main' into padding_free_dpo
dame-cell Dec 20, 2024
0784202
Merge branch 'main' into padding_free_dpo
dame-cell Dec 21, 2024
ddfed7c
update test_dpo_trainer.py
dame-cell Dec 21, 2024
955c7e8
update dpo_trainer.py
dame-cell Dec 21, 2024
ba4356e
update dpo_trainer.py
dame-cell Dec 21, 2024
c61abb5
update dpo_trainer.py
dame-cell Dec 21, 2024
64e9909
update dpo_trainer.py
dame-cell Dec 22, 2024
68186e7
Merge branch 'main' into padding_free_dpo
dame-cell Dec 22, 2024
60 changes: 4 additions & 56 deletions tests/test_dpo_trainer.py
@@ -472,7 +472,7 @@ def test_dpo_trainer_padding_token_is_none(self):

trainer.train()

def test_dpo_trainer_w_dataset_num_proc(self):
def test_dpo_trainer_padding_free_training(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
@@ -483,74 +483,22 @@ def test_dpo_trainer_w_dataset_num_proc(self):
learning_rate=9e-1,
eval_strategy="steps",
beta=0.1,
dataset_num_proc=5,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

tokenizer = AutoTokenizer.from_pretrained(self.model_id)
tokenizer.pad_token = None

with self.assertRaisesRegex(
ValueError,
expected_regex=r"Can't find `pad_token_id` in the `processing_class`. "
r"Explicitly set `tokenizer.pad_token` \(e.g. `tokenizer.pad_token = tokenizer.eos_token`\) "
r"before instantiating the trainer.",
):
trainer = DPOTrainer(
model=self.model,
ref_model=None,
args=training_args,
processing_class=tokenizer,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

trainer.train()

def test_tr_dpo_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
Contributor:

Would be good to have tests for this with gradient accumulation too, perhaps using pytest.mark.parameterize?

Author (dame-cell):

All right, will do. Thanks for reviewing 😎 (a parametrized version is sketched after this file's diff.)

learning_rate=9e-1,
eval_strategy="steps",
precompute_ref_log_probs=False,
sync_ref_model=True,
ref_model_mixup_alpha=0.5,
ref_model_sync_steps=1,
report_to="none",
)

dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

trainer = DPOTrainer(
model=self.model,
ref_model=self.ref_model,
ref_model=None,
args=training_args,
processing_class=self.tokenizer,
tokenizer=self.tokenizer,
padding_free=True,
train_dataset=dummy_dataset["train"],
eval_dataset=dummy_dataset["test"],
)

# params of the ref model as its the same as the model
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.ref_model.get_parameter(n)
# check the ref model's params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
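Following up on the review thread above, here is a sketch (not part of this PR) of how the padding-free test could be parametrized over gradient accumulation. It assumes the fixtures and imports already present in tests/test_dpo_trainer.py (self.model, self.tokenizer, tempfile, load_dataset, DPOConfig, DPOTrainer, and the parameterized package), and that the padding_free flag is read from DPOConfig:

```python
from parameterized import parameterized

@parameterized.expand([(1,), (2,), (4,)])
def test_dpo_trainer_padding_free_training(self, gradient_accumulation_steps):
    # Sketch only: padding-free training across several gradient accumulation settings.
    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = DPOConfig(
            output_dir=tmp_dir,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=gradient_accumulation_steps,
            max_steps=3,
            learning_rate=9e-1,
            eval_strategy="steps",
            beta=0.1,
            padding_free=True,  # assumes the flag lives on DPOConfig
            report_to="none",
        )

        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

        trainer = DPOTrainer(
            model=self.model,
            ref_model=None,
            args=training_args,
            processing_class=self.tokenizer,
            train_dataset=dummy_dataset["train"],
            eval_dataset=dummy_dataset["test"],
        )
        trainer.train()

        # A finite training loss is enough to show the padding-free path ran end to end.
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
```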
3 changes: 3 additions & 0 deletions trl/trainer/dpo_config.py
@@ -145,6 +145,8 @@ class DPOConfig(TrainingArguments):
for saving memory and speeding up training by not computing the logits for all tokens, especially in scenarios
when working with very long prompts where labels are ignored (-100).
[Read more](https://huggingface.co/docs/transformers/main/model_doc/llama#transformers.LlamaForCausalLM)
padding_free (`bool`, defaults to `False`):
Whether to use padding-free training. If set to `True`, padding tokens are dropped and the sequences in each batch are packed into a single row processed with position IDs, instead of a padded batch with an attention mask.
"""

learning_rate: float = 1e-6
@@ -192,3 +194,4 @@ class DPOConfig(TrainingArguments):
rpo_alpha: Optional[float] = None
discopop_tau: float = 0.05
use_num_logits_to_keep: bool = False
padding_free: bool = False
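For context, a minimal usage sketch of the new flag (not part of the diff), assuming it is consumed from DPOConfig as in the `self.padding_free = args.padding_free` line added below. The model id is a placeholder for a small causal LM; the dataset is the internal test set used in the tests above:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "your-small-causal-lm"  # placeholder, not a specific checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # DPOTrainer requires a pad token

dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

training_args = DPOConfig(
    output_dir="dpo-padding-free",
    per_device_train_batch_size=2,
    max_steps=3,
    padding_free=True,  # new flag added by this PR
    report_to="none",
)

trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
)
trainer.train()
```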
145 changes: 96 additions & 49 deletions trl/trainer/dpo_trainer.py
@@ -192,6 +192,7 @@ class DPOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.

"""

_tag_names = ["trl", "dpo"]
@@ -347,7 +348,7 @@ def make_inputs_require_grad(module, input, output):
self.model_adapter_name = args.model_adapter_name
self.ref_adapter_name = args.ref_adapter_name
self.reference_free = args.reference_free

self.padding_free = args.padding_free
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model or args.precompute_ref_log_probs:
@@ -1068,8 +1069,10 @@ def dpo_loss(

def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

We do this to avoid doing two forward passes, because it's faster for FSDP.
When `padding_free=True`, padding tokens are dropped and the sequences in the batch are packed into a single
continuous row processed with position IDs, saving the memory and compute otherwise spent on padding.
When `False`, the standard padded batch with an attention mask is used.
"""
num_examples = batch["prompt_input_ids"].shape[0]

@@ -1121,55 +1124,99 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
attention_mask[i] = torch.roll(attention_mask[i], shifts=-first_one_idx)
loss_mask[i] = torch.roll(loss_mask[i], shifts=-first_one_idx)

# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]
if self.padding_free:
# Pre-calculate sequence lengths
seq_lengths = attention_mask.sum(1)

# Truncate right
if self.args.max_length is not None:
# apply max_length to the individual sequence lengths before concatenation
if self.args.max_length is not None:
seq_lengths = torch.clamp(seq_lengths, max=self.args.max_length)

# truncate input_ids, attention_mask and loss_mask accordingly
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label

outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)

# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()

if self.use_num_logits_to_keep:
# Align labels with logits
# logits: -, -, [x2, x3, x4, x5, x6]
# ^ --------- ^ after logits[:, :-1, :]
# labels: [y0, y1, y2, y3, y4, y5, y6]
# ^ --------- ^ with num_logits_to_keep=4, [:, -4:]
# loss_mask: [0, 0, 0, 1, 1, 1, 1]
labels = labels[:, -num_logits_to_keep:]
loss_mask = loss_mask[:, -num_logits_to_keep:]

if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

# Compute the log probabilities of the labels
labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)
total_length = seq_lengths.sum().item()

# Pre-allocate tensors
concatenated_input_ids = torch.zeros(total_length, dtype=input_ids.dtype, device=input_ids.device)

# Fill tensors
current_idx = 0
sequence_boundaries = []
for i in range(input_ids.size(0)):
length = seq_lengths[i].item()
valid_tokens = input_ids[i, :length]
concatenated_input_ids[current_idx : current_idx + length] = valid_tokens
sequence_boundaries.append((current_idx, current_idx + length))
current_idx += length

# Create position ids
position_ids = torch.arange(total_length, device=input_ids.device)

# remove attention mask
model_kwargs.pop("attention_mask", None)

outputs = model(
input_ids=concatenated_input_ids.unsqueeze(0),
position_ids=position_ids.unsqueeze(0),
**model_kwargs,
)

# Process outputs
logits = outputs.logits[0, :-1, :]
labels = concatenated_input_ids[1:].clone()

# Calculate per-token log probabilities
per_token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

# Split sequences back into batch
start_idx = 0
batch_logps = []
for length in seq_lengths:
sequence_logps = per_token_logps[start_idx : start_idx + length - 1]
batch_logps.append(sequence_logps.sum())
start_idx += length

all_logps = torch.stack(batch_logps)

else:
# Truncate right
if self.args.max_length is not None:
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
loss_mask = loss_mask[:, : self.args.max_length]

if self.use_num_logits_to_keep:
# Compute num_logits_to_keep based on loss_mask pattern:
# [[0, 0, 0, x, x, x, x],
# [0, 0, 0, x, x, x, 0]]
# ^ start computing logits from here ([:, -(7-3+1):])
first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min()
num_logits_to_keep = loss_mask.shape[1] - first_compute_index
model_kwargs["num_logits_to_keep"] = num_logits_to_keep.item() + 1 # +1 for the first label

# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(attention_mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else attention_mask.size(1)
input_ids = input_ids[:, :first_empty_col]
attention_mask = attention_mask[:, :first_empty_col]
loss_mask = loss_mask[:, :first_empty_col]

outputs = model(input_ids=input_ids, attention_mask=attention_mask, **model_kwargs)

# Offset the logits by one to align with the labels
logits = outputs.logits[:, :-1, :]
labels = input_ids[:, 1:].clone()
loss_mask = loss_mask[:, 1:].bool()

if logits.shape[:2] != labels.shape[:2]:
# for llava, the returned logits include the image tokens (placed before the text tokens)
seq_len = labels.shape[1]
logits = logits[:, -seq_len:]

labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
per_token_logps[~loss_mask] = 0
all_logps = per_token_logps.sum(-1)

output = {}

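To make the padding-free branch above easier to follow, here is a self-contained toy sketch (not the PR code) of the same packing and splitting steps, with random logits standing in for the model call. The continuous position_ids mirror the current implementation in this PR:

```python
import torch

torch.manual_seed(0)
vocab_size, pad_id = 11, 0

# Padded batch of 3 sequences with true lengths 4, 2 and 3
input_ids = torch.tensor([[5, 6, 7, 8],
                          [3, 4, pad_id, pad_id],
                          [9, 1, 2, pad_id]])
attention_mask = (input_ids != pad_id).long()
seq_lengths = attention_mask.sum(1)        # tensor([4, 2, 3])
total_length = int(seq_lengths.sum())      # 9

# Pack every real token into one row, dropping the padding
packed = torch.cat([input_ids[i, : int(n)] for i, n in enumerate(seq_lengths)])
position_ids = torch.arange(total_length)  # continuous ids, as in the PR

# Stand-in for model(input_ids=packed[None], position_ids=position_ids[None])
logits = torch.randn(1, total_length, vocab_size)

# Shift by one: the logits at position j score the token at position j + 1
shifted_logits = logits[0, :-1]
labels = packed[1:]
per_token_logps = shifted_logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)

# Split the packed log-probs back into one summed log-prob per sequence,
# skipping the position that crosses a sequence boundary
start, logps = 0, []
for length in seq_lengths.tolist():
    logps.append(per_token_logps[start : start + length - 1].sum())
    start += length
all_logps = torch.stack(logps)
print(all_logps)  # one scalar log-prob per original sequence
```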
1 change: 1 addition & 0 deletions trl/trainer/ppo_config.py
@@ -65,5 +65,6 @@ class PPOConfig(OnPolicyConfig):
cliprange: float = 0.2
vf_coef: float = 0.1
cliprange_value: float = 0.2
"""Clip range for the value function."""
gamma: float = 1.0
lam: float = 0.95
3 changes: 1 addition & 2 deletions trl/trainer/ppo_trainer.py
@@ -1,4 +1,4 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -142,7 +142,6 @@ def __init__(
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding

# peft support
if not is_peft_available() and peft_config is not None:
raise ImportError(