From 110d0884c75f1cbb3ae8e041bce8436fa7596054 Mon Sep 17 00:00:00 2001 From: Seungjae Jung Date: Sat, 26 Oct 2024 01:15:08 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=8F=81=20Add=20`bos=5Ftoken=5Fid`=20only?= =?UTF-8?q?=20if=20it=20exists=20(#2279)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: sean.jung --- trl/trainer/bco_trainer.py | 25 ++++++++++++++----------- trl/trainer/kto_trainer.py | 25 ++++++++++++++----------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index c6ce2d4902..baf43a42bb 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -223,17 +223,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, ** ) # add BOS, which affects both prompt and the full completion - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] # add EOS, which affects only the full completion if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index fa1541bb44..e6991182a7 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -218,17 +218,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, ** ) # add BOS, which affects both prompt and the full completion - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] # add EOS, which affects only the full completion if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: 
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [