Skip to content

Commit

Permalink
🏁 Add bos_token_id only if it exists (#2279)
Browse files (browse the repository at this point in the history)
Co-authored-by: sean.jung <sean.jung@sean-ai.local>
  • Loading branch information
seanexp and sean.jung authored Oct 25, 2024
1 parent 57ba9b9 commit 110d088
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
25 changes: 14 additions & 11 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
)

# add BOS, which affects both prompt and the full completion
- if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
- batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
- f"{kwargs['prefix']}prompt_input_ids"
- ]
- batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
- batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
- f"{kwargs['prefix']}completion_input_ids"
- ]
- batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
- f"{kwargs['prefix']}completion_attention_mask"
- ]
+ if bos_token_id is not None:
+ if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
+ batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
+ f"{kwargs['prefix']}prompt_input_ids"
+ ]
+ batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[
+ f"{kwargs['prefix']}prompt_attention_mask"
+ ]
+ batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
+ f"{kwargs['prefix']}completion_input_ids"
+ ]
+ batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
+ f"{kwargs['prefix']}completion_attention_mask"
+ ]
# add EOS, which affects only the full completion
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
Expand Down
25 changes: 14 additions & 11 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
)

# add BOS, which affects both prompt and the full completion
- if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
- batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
- f"{kwargs['prefix']}prompt_input_ids"
- ]
- batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
- batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
- f"{kwargs['prefix']}completion_input_ids"
- ]
- batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
- f"{kwargs['prefix']}completion_attention_mask"
- ]
+ if bos_token_id is not None:
+ if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
+ batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
+ f"{kwargs['prefix']}prompt_input_ids"
+ ]
+ batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[
+ f"{kwargs['prefix']}prompt_attention_mask"
+ ]
+ batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
+ f"{kwargs['prefix']}completion_input_ids"
+ ]
+ batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
+ f"{kwargs['prefix']}completion_attention_mask"
+ ]
# add EOS, which affects only the full completion
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
Expand Down

0 comments on commit 110d088

Please sign in to comment.