Refactor DPO data processing #2209

Merged
merged 72 commits into main from refactor-dpo-data on Oct 21, 2024
Conversation

@qgallouedec (Member) commented Oct 9, 2024

What does this PR do?

This PR refactors DPO data processing to be less intricate and easier to understand and maintain.

The main changes are to the tokenize_row method and its equivalent for vision data, process_row.

This PR also solves a couple of issues:

  • Wrong reported logits
  • Wrong truncation logic

It also modifies the following:

  • Make max_length, max_prompt_length and max_completion_length optional
  • The handling of cases where merging the prompt and completion alters tokenization

Regression tests are presented in my next message.

Wrong truncation logic

Below is a simplified version of the current truncation logic (from main) used to handle sequences exceeding the allowed max_length:

max_length = 15  # example value
max_prompt_length = 5  # example value

# if the combined sequence is too long, truncate the prompt
if len(prompt) + len(completion) > max_length:
    prompt = prompt[-max_prompt_length:]

# if it's still too long, truncate the completion from the end
if len(prompt) + len(completion) > max_length:
    completion = completion[: max_length - max_prompt_length]

Issues with this logic

  1. Unnecessary prompt truncation: The prompt gets truncated more than necessary.
  2. Non-intuitive logic: It's not immediately clear how truncation is applied.
  3. Lack of max_completion_length handling: While max_prompt_length is applied, there's no equivalent handling for completions, making the logic even less intuitive.

Example

Consider the following:

prompt = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
completion = [11, 12, 13, 14, 15]

Applying the original truncation logic gives:

>>> prompt
[6, 7, 8, 9, 10]
>>> completion
[11, 12, 13, 14, 15]
>>> len(prompt) + len(completion)
10

In this case, the total length ends up being 10, even though max_length allows for up to 15 tokens. This means we could have preserved more of the prompt.
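
For reference, here is the simplified main-branch logic wrapped into a runnable function (a sketch of the behaviour described above, not the exact trainer code), reproducing the example:

def truncate_main(prompt, completion, max_length=15, max_prompt_length=5):
    # If the combined sequence is too long, truncate the prompt first...
    if len(prompt) + len(completion) > max_length:
        prompt = prompt[-max_prompt_length:]
    # ...and only then, if it is still too long, truncate the completion.
    if len(prompt) + len(completion) > max_length:
        completion = completion[: max_length - max_prompt_length]
    return prompt, completion

prompt = list(range(11))          # [0, 1, ..., 10]
completion = list(range(11, 16))  # [11, 12, 13, 14, 15]
print(truncate_main(prompt, completion))  # ([6, 7, 8, 9, 10], [11, 12, 13, 14, 15])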

Proposed Solution

Here's a more flexible and intuitive approach (simplified for clarity):

max_length = 15  # example value
max_prompt_length = 10  # example value
max_completion_length = 8  # example value

# Truncate prompt if max_prompt_length is defined
if max_prompt_length is not None:
    prompt = prompt[-max_prompt_length:]

# Truncate completion if max_completion_length is defined
if max_completion_length is not None:
    completion = completion[:max_completion_length]

# Ensure combined prompt and completion don't exceed max_length
if len(prompt) + len(completion) > max_length:
    completion = completion[:max_length - len(prompt)]

Benefits

  1. Retains more of the prompt: This logic avoids unnecessarily truncating the prompt.
  2. More intuitive: It's clearer when each truncation step happens, and the flow is easier to follow.
  3. Supports flexible lengths: Allows None for max_length, max_prompt_length, and max_completion_length, offering more control over how each part is truncated.

Result with the new logic

>>> prompt, completion
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15])
>>> len(prompt) + len(completion)
15

In this example, the prompt and completion are preserved within the allowed max_length without over-truncation.
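
For reference, the proposed logic wrapped into a runnable function (again a simplified sketch rather than the exact trainer code), reproducing the result above:

def truncate_new(prompt, completion, max_length=15, max_prompt_length=10, max_completion_length=8):
    # Truncate the prompt and the completion independently when limits are set...
    if max_prompt_length is not None:
        prompt = prompt[-max_prompt_length:]
    if max_completion_length is not None:
        completion = completion[:max_completion_length]
    # ...then cap the combined length by trimming the completion.
    if max_length is not None and len(prompt) + len(completion) > max_length:
        completion = completion[: max_length - len(prompt)]
    return prompt, completion

prompt = list(range(11))          # [0, 1, ..., 10]
completion = list(range(11, 16))  # [11, 12, 13, 14, 15]
print(truncate_new(prompt, completion))  # ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15])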

Wrong reported logits and log probs

The reported logits and log probs did not take padding into account, e.g.:

chosen_logits = logits.mean()

The new reporting first applies the mask before computing the mean:

chosen_logits = logits[loss_mask].mean()
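
To illustrate the difference, here is a toy example (shapes and values are made up; loss_mask is assumed to be a boolean mask over token positions):

import torch

# Toy tensors: a batch of 2 sequences of length 4 over a vocabulary of size 3.
logits = torch.randn(2, 4, 3)
# True on the positions that count towards the loss, False elsewhere (e.g. padding).
loss_mask = torch.tensor([[False, True, True, False],
                          [False, False, True, True]])

old_report = logits.mean()             # averages over every position, padding included
new_report = logits[loss_mask].mean()  # averages only over the masked positions
print(old_report.item(), new_report.item())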

When merging prompt and completion alters tokenization

Related #1960 and #2054

(very rare but worth mentioning)

There might be a difference between tokenizing the full concatenated text versus tokenizing the prompt and completion separately, and then merging their token lists.

Consider the following example:

prompt = "Lorem."
completion = "\n\nipsum"

Tokenizing the combined prompt and completion produces different results than tokenizing each part separately and then concatenating their token lists:

>>> tokenizer(prompt + completion)["input_ids"]
[32783, 382, 573, 1242]
>>> tokenizer(prompt)["input_ids"]
[32783, 13]
>>> tokenizer(completion)["input_ids"]
[271, 573, 1242]

This difference arises because we have a specific token for the string ".\n\n":

>>> tokenizer.decode([382])
'.\n\n'
>>> tokenizer.decode([13, 271])
'.\n\n'

In the first tokenization method, the ".\n\n" is treated as a single token, whereas in the second method, the token is split between the prompt and the completion.

In this example, the code on main moves the merged token (382) to the completion:

>>> prompt_ids = [32783]
>>> completion_ids = [382, 573, 1242]
>>> tokenizer.decode(prompt_ids)
'Lorem'
>>> tokenizer.decode(completion_ids)
'.\n\nipsum'

While this approach is functional, it slightly alters the original text since the final period is no longer part of the prompt. To avoid this, we can split the tokens without altering the data:

>>> prompt_ids = [32783, 13]
>>> completion_ids = [271, 573, 1242]
>>> tokenizer.decode(prompt_ids)
'Lorem.'
>>> tokenizer.decode(completion_ids)
'\n\nipsum'

This way, the period remains in the prompt, preserving the original structure of the text. Note that this scenario is quite rare; it typically does not occur in conversational datasets.
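
A quick way to check whether a given tokenizer exhibits this boundary effect (the model name below is only an example; any BPE tokenizer with merged punctuation/newline tokens can show it):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # example choice

prompt, completion = "Lorem.", "\n\nipsum"

merged_ids = tokenizer(prompt + completion, add_special_tokens=False)["input_ids"]
separate_ids = (
    tokenizer(prompt, add_special_tokens=False)["input_ids"]
    + tokenizer(completion, add_special_tokens=False)["input_ids"]
)

# The two lists decode to the same text but may differ when a single token
# (such as ".\n\n") spans the prompt/completion boundary.
print(merged_ids == separate_ids)
print(repr(tokenizer.decode(merged_ids)), repr(tokenizer.decode(separate_ids)))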

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


@qgallouedec (Member, Author) commented Oct 15, 2024

Regression report

I ran regression tests to ensure we don't break our DPO.

Scenarios tested

The following scenarios were assessed for potential impact by recent changes:

  • Encoder-decoder model
  • Decoder-only
  • Precompute ref
  • Auxiliary loss
  • Vision models

Dataset Selection

As discussed earlier, the new and old (main) implementations are not equivalent in cases involving:

  • Merging of prompt and completion leading to token merging
  • Cases where truncation is needed

To avoid these cases, I used a conversational dataset with short content: trl-lib/ultrafeedback_binarized. I applied the following truncation preprocessing to limit sequence length:

def truncate(example):
    return {
        "prompt": [{"role": "user", "content": example["chosen"][0]["content"][:100]}],
        "chosen": [{"role": "assistant", "content": example["chosen"][1]["content"][:100]}],
        "rejected": [{"role": "assistant", "content": example["rejected"][1]["content"][:100]}],
    }

dataset = dataset.map(truncate, desc="Truncate examples")

Expected Changes

Differences in log probabilities (logps) are expected due to initial miscalculations, as mentioned in my previous post.

Encoder-decoder

For this one I needed a custom script:

# dpo_encdec.py
import torch
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from trl import (
    DPOConfig,
    DPOScriptArguments,
    DPOTrainer,
    ModelConfig,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if __name__ == "__main__":
    parser = TrlParser((DPOScriptArguments, DPOConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )
    model = AutoModelForSeq2SeqLM.from_pretrained("google/t5-v1_1-small")
    peft_config = get_peft_config(model_config)
    if peft_config is None:
        ref_model = AutoModelForSeq2SeqLM.from_pretrained(
            model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
        )
    else:
        ref_model = None
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    dataset = load_dataset(script_args.dataset_name)

    def truncate(example):
        return {
            "prompt": [{"role": "user", "content": example["chosen"][0]["content"][:100]}],
            "chosen": [{"role": "assistant", "content": example["chosen"][1]["content"][:100]}],
            "rejected": [{"role": "assistant", "content": example["rejected"][1]["content"][:100]}],
        }

    dataset = dataset.map(truncate, desc="Truncate examples")

    trainer = DPOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    trainer.train()
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
# 8 GPUs
accelerate launch dpo_encdec.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path google/t5-v1_1-small \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir t5-v1_1-DPO-main \
    --no_remove_unused_columns
[Screenshot 2024-10-16 at 18 06 03]

Decoder-only

# 8 GPUs
accelerate launch examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-DPO-main \
    --no_remove_unused_columns
[Screenshot 2024-10-16 at 17 52 25]

Comment

Not sure exactly why the chosen and rejected rewards don't match, but the margin still seems very close.

Precompute reference

# 8 GPUs
accelerate launch examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-DPO-main \
    --no_remove_unused_columns \
    --precompute_ref_log_probs
[Screenshot 2024-10-16 at 18 29 21]

Comment

The curves precisely match their corresponding run without --precompute_ref_log_probs.

Auxiliary loss

Modify the example script to add:

model.config.output_router_logits = True
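
As a rough sketch (the surrounding lines are assumed from examples/scripts/dpo.py and simplified), the flag goes right after the policy model is created:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
# Return the MoE router logits so the auxiliary load-balancing loss can be tracked.
model.config.output_router_logits = True

Launch command: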
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-DPO-2209 \
    --gradient_checkpointing \
    --max_length 256 \
    --use_peft \
    --bf16
[Screenshot 2024-10-17 at 17 07 18]

Comment

Not sure if the training helped a lot, but at least the results are consistent between main and #2209.

We've got a new aux_loss plot!

Vision model

@kashif self-requested a review on October 15, 2024 at 11:21
@qgallouedec (Member, Author) commented, quoting an earlier review comment:

> Regarding the difference in the chosen/rejected rewards of your regression tests, have you looked at the impact on downstream evals like IFEval / AlpacaEval / MixEval? I can run those for you if you have the checkpoints handy and then we can be pretty sure it's fine

Nice idea, I'll send you the checkpoints!

@qgallouedec (Member, Author) commented Oct 17, 2024

@lewtun

Here is one:

Trained with

accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/dpo.py \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --model_name_or_path Qwen/Qwen2.5-7B-Instruct \
    --learning_rate 5.0e-7 \
    --num_train_epochs 1 \
    --gradient_checkpointing \
    --logging_steps 10 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2.5-7B-DPO-2209 \
    --gradient_checkpointing \
    --max_length 512 \
    --use_peft \
    --bf16 \
    --push_to_hub

Another data point for the regression test:

[Screenshot 2024-10-18 at 00 10 32]

@qgallouedec (Member, Author) commented:

IFEval

The new implementation seems to improve results

| Model | inst_level_loose_acc | inst_level_strict_acc | prompt_level_loose_acc | prompt_level_strict_acc |
|---|---|---|---|---|
| Qwen2.5-7B | 0.7122 | 0.6631 | 0.6026 ± 0.0211 | 0.5416 ± 0.0214 |
| Qwen2.5-7B-DPO-main | 0.7182 | 0.6751 | 0.6155 ± 0.0209 | 0.5693 ± 0.0213 |
| Qwen2.5-7B-DPO-2209 | 0.7326 | 0.6775 | 0.6303 ± 0.0208 | 0.5656 ± 0.0213 |

@qgallouedec merged commit 92f6d24 into main on Oct 21, 2024 (10 checks passed)
@qgallouedec deleted the refactor-dpo-data branch on October 21, 2024 at 10:47