
precompute_ref_log_probs not working correctly? #2423

Closed
dakru012 opened this issue Dec 2, 2024 · 6 comments
Labels
🏋 DPO Related to DPO

Comments

@dakru012 (Contributor) commented Dec 2, 2024

System Info

  • Platform: Linux-6.8.0-49-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA GeForce RTX 4090
  • Transformers version: 4.46.2
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • Datasets version: 3.1.0
  • HF Hub version: 0.26.0
  • TRL version: 0.12.0
  • bitsandbytes version: 0.44.1
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.54.3
  • PEFT version: 0.13.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from datasets import Dataset
from trl import DPOConfig, DPOTrainer, extract_prompt

# `model`, `ref`, and `tokenizer` are assumed to be loaded beforehand

# basic example dataset
train_dataset = Dataset.from_dict({
    "chosen": [
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
    ],
    "rejected": [
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
        [{"role": "system", "content": "Answer truthfully."}, {"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
    ],
})

train_dataset = train_dataset.map(extract_prompt)

training_args = DPOConfig(
    output_dir='DPO_output',
    logging_steps=10,
    loss_type='sigmoid',
    bf16=True,
    precompute_ref_log_probs=True,
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)

trainer.train()

Expected behavior

Hi,
I have some questions about a potential issue or misunderstanding on my side.
The point of precompute_ref_log_probs is to calculate the reference log probabilities for the whole dataset before the actual training starts, so that during training we can just load the precomputed values and free the GPU memory otherwise needed for the ref model, right?
However, it seems like the precomputed log probabilities are never actually loaded.
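
A quick way to see this, as a rough sketch (it assumes that the precomputation in get_train_dataloader() stores the values as ref_chosen_logps / ref_rejected_logps dataset columns, which is what I observed):

# `trainer` is the DPOTrainer from the reproduction script above
dataloader = trainer.get_train_dataloader()  # triggers the ref log-prob precomputation

# the precomputed values do end up on the dataset ...
print(trainer.train_dataset.column_names)
# ... expected to include 'ref_chosen_logps' and 'ref_rejected_logps'

# ... but they never make it into the collated batch:
batch = next(iter(dataloader))
print("ref_chosen_logps" in batch)  # prints False -> the ref model is run again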

In the corresponding part in get_batch_loss_metrics():

def get_batch_loss_metrics(
    self,
    model,
    batch: Dict[str, Union[List, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    model_output = self.concatenated_forward(model, batch)

    # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model
    if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch:
        # LOADED
        ref_chosen_logps = batch["ref_chosen_logps"]
        ref_rejected_logps = batch["ref_rejected_logps"]
    else:
        # COMPUTED AGAIN
        ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch)
...

The if condition is never true, even when the log probabilities have been precomputed, which results in unnecessary forward passes of the ref model.
This is because the PreferenceCollator does not include ref_chosen_logps and ref_rejected_logps in the batch.

I made some changes to the collator to include those (see the sketch below), but first I wanted to make sure that I understood the precompute_ref_log_probs argument correctly.
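
For reference, the change I tried looks roughly like this. It is only a sketch, not the actual TRL code: the exact collator internals may differ, and add_precomputed_ref_logps is just an illustrative helper name. The idea is that, after the collator has padded the prompt/chosen/rejected ids into output, the precomputed per-example values are copied through when present.

import torch

def add_precomputed_ref_logps(examples, output):
    """Copy precomputed reference log-probs from the raw examples into the collated batch."""
    if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
        output["ref_chosen_logps"] = torch.tensor(
            [ex["ref_chosen_logps"] for ex in examples], dtype=torch.float32
        )
        output["ref_rejected_logps"] = torch.tensor(
            [ex["ref_rejected_logps"] for ex in examples], dtype=torch.float32
        )
    return output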

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
@qgallouedec (Member)

That's a good catch, thanks @dakru012! Do you want to submit a PR to fix it?

qgallouedec added the 🏋 DPO Related to DPO label Dec 2, 2024
@SwayamInSync (Contributor)

> That's a good catch, thanks @dakru012! Do you want to submit a PR to fix it?

I think these lines within concatenated_forward are the culprit: the names should be [ref_chosen_logps, ref_rejected_logps] instead of [chosen_logps, rejected_logps], and then the same case needs to be handled in the compute_ref_log_probs function.

        output["chosen_logps"] = all_logps[:num_examples]
        output["rejected_logps"] = all_logps[num_examples:]

Let me know if a PR is already open; otherwise I can include the relevant fixes in #2426 or make a new one.

@dakru012 (Contributor, Author) commented Dec 2, 2024

@SwayamInSync I don't think that's the problem.
I will take a look at it again and do a PR, but it is already midnight here so I gotta sleep first 😴

@SwayamInSync (Contributor)

While looking into another issue in the transformers library, I think this function is the cause of this issue:

def _set_signature_columns_if_needed(self):
    # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
    # By default, this method sets `self._signature_columns` to the model's expected inputs.
    # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
    # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
    if self._signature_columns is None:
        self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes"]

It is used to remove the unused columns, and since no ref columns are defined here, they all get removed. Changing the above to

def _set_signature_columns_if_needed(self):
    # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
    # By default, this method sets `self._signature_columns` to the model's expected inputs.
    # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
    # Instead, we set them to the columns expected by `DPODataCollatorWithPadding`, hence the override.
    if self._signature_columns is None:
        self._signature_columns = ["prompt_input_ids", "chosen_input_ids", "rejected_input_ids", "image_sizes", "ref_chosen_logps", "ref_rejected_logps"]

fixes that condition check.

cc: @dakru012

@dakru012 (Contributor, Author) commented Dec 3, 2024

@SwayamInSync That is a good find, I overlooked that one and just set remove_unused_columns to False. I will test it and check the other trainers for similar problems.

I think there is also a small error in the data_collator description. It says that DPODataCollatorWithPadding is the default collator, but it seems to be PreferenceCollator now.
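
To make the workaround above concrete, the config side of my local change looks roughly like this (just a sketch; on its own it is not enough, because the stock collator still drops the values, see the collator sketch earlier in this thread):

training_args = DPOConfig(
    output_dir="DPO_output",
    precompute_ref_log_probs=True,
    # workaround: keep all dataset columns so the precomputed
    # ref_chosen_logps / ref_rejected_logps are not stripped before collation
    remove_unused_columns=False,
)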

@SwayamInSync (Contributor)

> @SwayamInSync That is a good find, I overlooked that one and just set remove_unused_columns to False. I will test it and check the other trainers for similar problems.
>
> I think there is also a small error in the data_collator description. It says that DPODataCollatorWithPadding is the default collator, but it seems to be PreferenceCollator now.

Hey, awesome, and yes, the documentation about the collator is misleading there. I will drop a quick fix for both in a PR later; please feel free to add any modifications needed.
