Empty prompts crash in KTOTrainer #2087

Closed · 2 of 4 tasks
gabikadlecova opened this issue Sep 19, 2024 · 1 comment · Fixed by #2093
Labels
🐛 bug Something isn't working

Comments

@gabikadlecova (Contributor)

System Info

Hello,
I discovered a possible bug while reviewing a student's thesis. In trl/trainer/kto_trainer.py, the function _process_tokens crashes if the prompt is empty (i.e., a zero-length prompt_input_ids).

I am not that familiar with the details of KTO - is an empty prompt a valid input? I guess an empty completion is not, but I can imagine that fine-tuning on good output data without a specific prompt might be a valid use case.

The fix is a one-liner and I can draft a PR (a sketch of the guard appears after the traceback below); I just wanted to discuss here whether it should be fixed to allow empty prompts.

TRL version: latest (current main)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Minimal example showing how to reproduce the crash:

from trl.trainer.kto_trainer import _process_tokens
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-coder-1.3b-base')  # or any other

# Some values not related to the bug are just random (nonsensical) placeholders
# (e.g. attention_mask); the trigger is the empty prompt_input_ids list.
_process_tokens(
    {'prompt': '', 'completion': '', 'answer_input_ids': [100], 'prompt_input_ids': [],
     'prompt_attention_mask': [''], 'answer_attention_mask': [''], 'label': 1},
    prefix='',
    is_encoder_decoder=False,
    tokenizer=tokenizer,
    max_completion_length=100,
    max_prompt_length=100,
    max_length=100,
    label_pad_token_id=100,
)

Exception:

File .../trl_git/trl/trainer/kto_trainer.py:186, in _process_tokens(example, model, **kwargs)
    183     print(example)
    184     print(example["completion"])
--> 186 if bos_token_id != all_tokens["prompt_input_ids"][0]:
    187     max_length -= 1
    188 if eos_token_id != all_tokens["answer_input_ids"][-1]:

IndexError: list index out of range

The same applies to answer_input_ids, but I guess an empty answer is not a valid case.
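
For reference, a minimal sketch of what the one-line guard inside _process_tokens could look like, assuming the intent of the surrounding code is to reserve space for a BOS/EOS token that gets added later when it is not already present; the actual change merged in #2093 may differ:

# Sketch of the guard (an assumption, not necessarily the merged fix): an empty
# prompt will have a BOS token prepended later, so reserve space for it instead
# of indexing into an empty list.
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
    max_length -= 1
# The symmetric guard for the answer side, should empty answers ever be allowed:
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
    max_length -= 1

With this guard, an empty prompt falls through to the normal "token missing" branch rather than raising IndexError, which matches the expected output below (BOS prepended to the empty prompt).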

Expected behavior

Output after fixing the condition at the crash point:

{'prompt': '',
 'completion': '',
 'label': 1,
 'prompt_input_ids': [32013],
 'prompt_attention_mask': [1, ''],
 'completion_input_ids': [32013, 100, 32014],
 'completion_attention_mask': [1, '', '', 1],
 'completion_labels': [100, 100, 32014]}
gabikadlecova added the 🐛 bug (Something isn't working) label on Sep 19, 2024
@kashif (Collaborator) commented Sep 20, 2024

Yes, let's fix it just in case; as far as I can tell it's an edge case, so at the very least it should not crash... happy to have your PR!
