diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index 517da43f55..f1e9bcb4d8 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -119,6 +119,8 @@ class SetupChatFormatTestCase(unittest.TestCase): def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM") + # remove the built-in chat_template to simulate a tokenizer that has no chat_template + self.tokenizer.chat_template = None def test_setup_chat_format(self): original_tokenizer_len = len(self.tokenizer) diff --git a/trl/models/utils.py b/trl/models/utils.py index afdc944154..562b8617ed 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -84,6 +84,8 @@ def setup_chat_format( """ Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. + If the tokenizer already has a chat template, this will raise an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`. + Args: model (`~transformers.PreTrainedModel`): The model to be modified. tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. @@ -94,6 +96,12 @@ def setup_chat_format( model (`~transformers.PreTrainedModel`): The modified model. tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. """ + # check if the tokenizer already has a chat template + if tokenizer.chat_template is not None: + raise ValueError( + "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None" + ) + # check if format available and retrieve if format not in FORMAT_MAPPING: raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")