Refactor DPO data processing #2209
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Regression report

I ran regression tests to ensure we don't break our DPO.

Scenarios tested

The following scenarios were assessed for potential impact by recent changes:
Dataset Selection

As discussed earlier, the new and old (`main`) implementations handle truncation of over-long examples differently.
To avoid these cases, I used a conversational dataset with short content:

```python
def truncate(example):
    return {
        "prompt": [{"role": "user", "content": example["chosen"][0]["content"][:100]}],
        "chosen": [{"role": "assistant", "content": example["chosen"][1]["content"][:100]}],
        "rejected": [{"role": "assistant", "content": example["rejected"][1]["content"][:100]}],
    }


dataset = dataset.map(truncate, desc="Truncate examples")
```

Expected Changes

Differences in log probabilities (logps) are expected due to initial miscalculations, as mentioned in my previous post.

Encoder-decoder

For this one I needed a custom script:

```python
# dpo_encdec.py
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from trl import (
    DPOConfig,
    DPOScriptArguments,
    DPOTrainer,
    ModelConfig,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if __name__ == "__main__":
    parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    # Policy model: an encoder-decoder (seq2seq) architecture
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        # Without PEFT, an explicit reference model is needed
        ref_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    dataset = load_dataset(script_args.dataset_name)

    def truncate(example):
        return {
            "prompt": [{"role": "user", "content": example["chosen"][0]["content"][:100]}],
            "chosen": [{"role": "assistant", "content": example["chosen"][1]["content"][:100]}],
            "rejected": [{"role": "assistant", "content": example["rejected"][1]["content"][:100]}],
        }

    dataset = dataset.map(truncate, desc="Truncate examples")

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()

    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
```
[Plot: training curves, encoder-decoder run]

Decoder-only
[Plot: training curves, decoder-only run]

Comment: Not sure exactly why the chosen and rejected don't match, but the margin still seems to be very close.

Precompute reference
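The exact configuration for this run isn't shown here; a minimal sketch of how it could be enabled, assuming the standard `DPOConfig` option:

```python
from trl import DPOConfig

# Assumed setup (illustrative): precompute the reference model's log probs once,
# so no reference forward pass is needed during training.
training_args = DPOConfig(output_dir="dpo-precompute", precompute_ref_log_probs=True)
```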
[Plot: training curves, precomputed reference log probs]

Comment: The curves precisely match those of the corresponding run without precomputation.

Auxiliary loss

Modify the example script to add `model.config.output_router_logits = True`.
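A sketch of that modification (the model name is an assumption; any MoE model exposing `output_router_logits` works the same way):

```python
from transformers import AutoModelForCausalLM

# Illustrative: load a MoE policy and enable the router logits so the
# auxiliary load-balancing loss can be added on top of the DPO loss.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
model.config.output_router_logits = True
```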
[Plot: training curves, auxiliary loss run]

Comment: Not sure if the training helped a lot, but at least you have consistent results between `main` and #2209.

We've a new Vision model
Nice idea, I'll send you the checkpoints!
Here is one:
Trained with
Another data point for the regression test:

[Plot: regression training curves]
IFEval

The new implementation seems to improve results.
What does this PR do?
This PR refactors DPO data processing to be less intricate and easier to understand and maintain.
The main changes are to the method `tokenize_row` and its equivalent for vision data, `process_row`. This PR also solves a couple of issues:
It also modifies the following: `max_length`, `max_prompt_length`, and `max_completion_length` are now optional.

Regression tests are presented in my next message.
Wrong truncation logic
Below is a simplified version of the current truncation logic (from `main`) used to handle sequences exceeding the allowed `max_length`:
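A rough sketch of that logic (variable names are illustrative, not the exact code from `main`):

```python
# Simplified view of the main-branch truncation (illustrative)
if len(prompt_ids) + len(completion_ids) > max_length:
    # 1) hard-truncate the prompt to max_prompt_length, keeping its end
    prompt_ids = prompt_ids[-max_prompt_length:]
if len(prompt_ids) + len(completion_ids) > max_length:
    # 2) if still too long, cap the completion at the remaining budget
    completion_ids = completion_ids[: max_length - max_prompt_length]
```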
Issues with this logic
- No `max_completion_length` handling: while `max_prompt_length` is applied, there's no equivalent handling for completions, making the logic even less intuitive.

Example
Consider the following:
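For concreteness, assume the following values (illustrative assumptions, chosen to match the 10-versus-15 outcome described below):

```python
# Assumed example values (illustrative)
prompt_ids = list(range(20))     # a 20-token prompt
completion_ids = list(range(5))  # a 5-token completion
max_length = 15
max_prompt_length = 5
```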
Applying the original truncation logic gives:
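With the assumed values above, the simplified `main` logic yields:

```python
# 20 + 5 > 15, so the prompt is hard-truncated first
prompt_ids = prompt_ids[-max_prompt_length:]   # 20 -> 5 tokens
# the completion is not touched: 5 + 5 = 10 <= max_length, so the second branch is skipped
len(prompt_ids) + len(completion_ids)          # 10, although max_length allows 15
```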
In this case, the total length ends up being 10, even though `max_length` allows for up to 15 tokens. This means we could have preserved more of the prompt.

Proposed Solution
Here's a more flexible and intuitive approach (simplified for clarity):
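A minimal sketch of the kind of logic described, assuming each limit is optional and applied independently (an illustration, not the exact diff from this PR):

```python
# Each limit is optional and applied to its own part of the sequence (illustrative)
if max_prompt_length is not None:
    prompt_ids = prompt_ids[-max_prompt_length:]              # keep the end of the prompt
if max_completion_length is not None:
    completion_ids = completion_ids[:max_completion_length]   # keep the start of the completion
if max_length is not None:
    # only trim what actually exceeds the overall budget, taking it from the prompt
    overflow = len(prompt_ids) + len(completion_ids) - max_length
    if overflow > 0:
        prompt_ids = prompt_ids[overflow:]
```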
Benefits

- Allows `None` for `max_length`, `max_prompt_length`, and `max_completion_length`, offering more control over how each part is truncated.

Result with the new logic
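Continuing the assumed example under the sketch above, and leaving `max_prompt_length` unset (which the refactor now allows):

```python
# completion_ids -> unchanged                      : 5 tokens
# prompt_ids     -> trimmed only to fit the budget : 10 tokens (instead of 5)
# total length                                     : 15 tokens == max_length
```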
In this example, the prompt and completion are preserved within the allowed `max_length` without over-truncation.

Wrong reported logits and log probs
The reported logits and log probs didn't take the padding into account. For example:
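Something along these lines (a sketch; the tensor names and values are illustrative):

```python
import torch

# Illustrative tensors: 2 sequences, 4 positions, last position of the second one is padding
chosen_logits = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                              [1.0, 2.0, 3.0, 0.0]])
chosen_loss_mask = torch.tensor([[True, True, True, True],
                                 [True, True, True, False]])

# Old-style reporting (sketch): the mean runs over every position,
# so the padded position is counted as if it were a real token.
print(chosen_logits.mean())  # 2.0, pulled down by the padding zero
```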
The new reporting first applies the mask before computing the mean:
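Reusing the illustrative tensors from the sketch above:

```python
# New-style reporting (sketch): select only positions covered by the loss mask
# before averaging, so padding no longer skews the reported value.
print(chosen_logits[chosen_loss_mask].mean())  # ~2.29, mean over the 7 real positions
```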
When merging prompt and completion alters tokenization
Related #1960 and #2054
(very rare but worth mentioning)
There might be a difference between tokenizing the full concatenated text versus tokenizing the prompt and completion separately, and then merging their token lists.
Consider the following example:
Tokenizing the combined prompt and completion produces different results than tokenizing each part separately and then concatenating their token lists:
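The exact strings of the original example aren't shown here; the snippet below is an illustrative way to observe the effect (the tokenizer checkpoint and the text are assumptions):

```python
from transformers import AutoTokenizer

# Illustrative check, not the PR's original example
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
prompt = "The sky is blue."
completion = "\n\nIndeed, it is."

merged = tok(prompt + completion, add_special_tokens=False)["input_ids"]
separate = (
    tok(prompt, add_special_tokens=False)["input_ids"]
    + tok(completion, add_special_tokens=False)["input_ids"]
)
print(merged == separate)  # may be False if the vocabulary has a ".\n\n" merge
```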
This difference arises because we have a specific token for the string `".\n\n"`: in the first tokenization method, `".\n\n"` is treated as a single token, whereas in the second method, the string is split between the prompt and the completion.

In this example, the `main` code moves the merging token (`382`) to the completion. While this approach is functional, it slightly alters the original text, since the final period is no longer part of the prompt. To avoid this, we can split the tokens without altering the data:
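A sketch of what this means in practice (names reused from the snippet above; this illustrates the behaviour, not the exact implementation):

```python
# Tokenize prompt and completion independently and keep both lists as-is:
# the final "." stays in the prompt's tokens, "\n\n" stays in the completion's,
# and no merged boundary token is moved from one side to the other.
prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
completion_ids = tok(completion, add_special_tokens=False)["input_ids"]
input_ids = prompt_ids + completion_ids
```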
This way, the period remains in the prompt, preserving the original structure of the text. Note that this scenario is quite rare; it typically does not occur in conversational datasets.
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.