From bfc77b0f3628c8df43f974873344124b8c947c26 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 29 Oct 2024 10:14:51 +0700 Subject: [PATCH] =?UTF-8?q?Feat:=20Add=20support=20for=20tokenizer?= =?UTF-8?q?=E2=80=99s=20or=20custom=20jinja=20chat=5Ftemplate=20(#1970)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow using tokenizer's default chat template with fallbacks Summary of changes: 1. Adds `tokenizer_default` as option for `chat_template` in `chat_template` prompt strategy that allows using the chat template from tokenizer's config.json 2. Allows falling back to chat templates available in axolotl if tokenizer does not have a chat template 3. Adds a mistral chat template which supports system message - taken from https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja --- Why? Many popular models are not trained with chatml format. As a result for the model to correctly learn chatml we have to turn on train_on_inputs which requires more compute and time. If we can use the model's already learned chat template we can just learn the output tokens --- Todo: - Write tests * Add tests * Fix lint and bug post merge from main * Add option `chat_template_jinja` to provide a jinja template * remove custom mistral template * Address review comments and add docs * Update docs/dataset-formats/conversation.qmd Co-authored-by: NanoCode012 * fix: set default to tokenizer template * Merge branch 'main' into cj_tokenizer_default_prompt_template * chore: remove redundant function * fix: re-arrange enum declaration position * fix: refactor artifact left from main merge * feat(doc): updated config with chat template options and clarified examples * chore: clarify doc * chore: added example for non-default template * chore: refactor * fix: test * fix: config being dropped and unittest to catch that * chore: lint * chore: skip duplicate * fix: rename var after merge * feat: add test for levy's dpo case * fix: remove default setting on edge case where chat template overriden in dataset section * feat: handle sharegpt deprecation better in docs * feat: add example using fallback * feat: handles chat_template requiring specific user/assistant order * fix: update test based on new defaults * fix: imported name incorrectly updated on merge * chore: lint * fix: update dummy message to prevent potential overlap with real content * fix(doc): formatting * fix: update bradleyterry to use new chat_template --------- Co-authored-by: Chirag Jain --- README.md | 2 +- docs/config.qmd | 57 ++++- docs/dataset-formats/conversation.qmd | 137 ++++++++++ src/axolotl/cli/__init__.py | 4 +- src/axolotl/core/trainer_builder.py | 4 +- .../bradley_terry/__init__.py | 2 +- .../bradley_terry/chat_template.py | 42 ++-- .../prompt_strategies/chat_template.py | 8 +- .../prompt_strategies/dpo/chat_template.py | 24 +- .../prompt_strategies/orpo/chat_template.py | 29 +-- src/axolotl/prompt_strategies/sharegpt.py | 2 +- src/axolotl/utils/chat_templates.py | 89 ++++++- src/axolotl/utils/config/__init__.py | 1 + .../config/models/input/v0_4_1/__init__.py | 129 +++++++--- src/axolotl/utils/models.py | 7 +- .../test_chat_template_utils.py | 125 +++++++++ .../prompt_strategies/test_chat_templates.py | 14 +- .../test_chat_templates_advanced.py | 26 +- .../test_dpo_chat_templates.py | 78 +++++- tests/test_validation_dataset.py | 238 ++++++++++++++++++ 20 files changed, 900 insertions(+), 118 deletions(-) create mode 100644 
tests/prompt_strategies/test_chat_template_utils.py create mode 100644 tests/test_validation_dataset.py diff --git a/README.md b/README.md index 4ce7a351b..21b954a56 100644 --- a/README.md +++ b/README.md @@ -383,7 +383,7 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod - typescript type: ... # unimplemented custom format - # fastchat conversation (deprecation soon, use chat_template) + # fastchat conversation (deprecation soon, use chat_template https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template) # See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py - path: ... type: sharegpt diff --git a/docs/config.qmd b/docs/config.qmd index 703d58775..a7bf9080b 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -83,7 +83,7 @@ lora_on_cpu: true datasets: # HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files - path: vicgalle/alpaca-gpt4 - # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] + # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] type: alpaca # format | format: (chat/instruct) | .load_ ds_type: # Optional[str] (json|arrow|parquet|text|csv) defines the datatype when path is a file data_files: # Optional[str] path to source data files @@ -124,6 +124,48 @@ datasets: # For `completion` datsets only, uses the provided field instead of `text` column field: + # Using chat template + - path: ... + # Set type to `chat_template` to use this strategy + type: chat_template + # Specify the name of the chat template to use + # The name of the chat template to use for training, following values are supported: + # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default. + # - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py + # - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to if the tokenizer does not have a chat template else default to tokenizer. E.g. tokenizer_default_fallback_chatml. + # - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. + chat_template: tokenizer_default + # Custom jinja template for chat template. This will be only used if `chat_template` is set to `jinja` or empty (in which case chat_template is automatically set to `jinja`). + chat_template_jinja: + # The key in the data example that contains the messages. Default is "messages". + field_messages: messages + # The key in the message turn that contains the role. Default is "role". + message_field_role: role + # The key in the message turn that contains the content. Default is "content". + message_field_content: content + # Optional[Dict[str, List]]. Roles mapping for the messages. + roles: + user: ["human", "user"] + assistant: ["gpt", "assistant", "ai"] + system: ["system"] + + ## NOTE: Leaving the below empty will default to using the simple legacy tokenization strategy where only last message is trained on. + + # Optional[List[str]]. Roles to train on. The tokens from these roles will be considered for the loss. + roles_to_train: ["gpt", "assistant"] + # Optional[str]. 
Which EOS tokens to train on in the conversation. Possible values are: + # - all: train on all EOS tokens + # - turn: train on the EOS token at the end of each trainable turn + # - last: train on the last EOS token in the conversation + train_on_eos: last + # The key in the message turn that indicates via boolean whether tokens of a turn should be considered for training. Useful to selectively train on certain turns besides the `roles_to_train`. + message_field_training: training + # The key in the message turn that contains the training details. Useful to selectively train on certain tokens in a turn. + # The value of the key is a List[Dict] containing `begin_offset` (start character index in content), `end_offset` (end character index in content), and `train` (boolean whether to train). + # See example at `docs/dataset-formats/conversation.qmd` + message_field_training_detail: train_detail + + # If false, the datasets will not be shuffled and will keep their original order in `datasets`. # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true. shuffle_merged_datasets: true @@ -142,9 +184,16 @@ test_datasets: # use RL training: 'dpo', 'ipo', 'kto' rl: -# Saves the desired chat template to the tokenizer_config.json for easier inferencing -# Currently supports chatml and inst (mistral/mixtral) -chat_template: chatml +# The name of the chat template to use for training, following values are supported: +# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. +# - alpaca/inst/chatml/gemma/cohere/llama3/phi_3/deepseek_v2/jamba: These chat templates are available in the axolotl codebase at src/axolotl/utils/chat_templates.py +# - tokenizer_default_fallback_*: where * is the name of the chat template to fallback to. E.g. tokenizer_default_fallback_chatml. This is useful when the chat template is not available in the tokenizer. +# - jinja: Uses a custom jinja template for the chat template. The custom jinja template should be provided in the chat_template_jinja field. +# The selected chat template will be saved to the tokenizer_config.json for easier inferencing +# Note: It is recommended to set train_on_inputs to true when using a chat template that is different from the model's default chat template. +chat_template: tokenizer_default +# custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null. +chat_template_jinja: null # Changes the default system message default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml. # Axolotl attempts to save the dataset as an arrow after packing the data together so diff --git a/docs/dataset-formats/conversation.qmd b/docs/dataset-formats/conversation.qmd index 28d13c987..c7273c5be 100644 --- a/docs/dataset-formats/conversation.qmd +++ b/docs/dataset-formats/conversation.qmd @@ -6,6 +6,8 @@ order: 3 ## sharegpt +UPDATE: ShareGPT is being deprecated in the next release. Please see `chat_template` section below. + conversations where `from` is `human`/`gpt`. 
(optional: first row with role `system` to override default system prompt) ```{.json filename="data.jsonl"} @@ -69,3 +71,138 @@ creates a chat where bot is asked to tell a joke, then explain why the joke is f ```{.json filename="data.jsonl"} {"conversations": [{"title": "...", "text": "...", "explanation": "..."}]} ``` + + +## chat_template + +Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2. + +```{.json filename="data.jsonl"} +{"conversations": [{"role": "...", "content": "..."}]} +``` + +See `config.qmd` for full configs and supported templates. + +### Migrating from sharegpt + +Most configs can be adapted as follows: + +```yaml +# old +chat_template: chatml +datasets: + - path: ... + type: sharegpt + conversation: chatml + +# new (if using tokenizer's chat_template) +datasets: + - path: ... + type: chat_template + + field_messages: conversations + message_field_role: from + message_field_content: value + +# new (if setting a new chat_template like chatml, gemma, etc) +chat_template: chatml +datasets: + - path: ... + type: chat_template + + field_messages: conversations + message_field_role: from + message_field_content: value +``` + +We recommend checking the below examples for other usecases. + +### Examples + +1. Using the default chat template in the tokenizer_config.json on OpenAI messages format, training on only last message. + +```yaml +datasets: + - path: ... + type: chat_template +``` + +2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages. + +```yaml +chat_template: gemma # this overwrites the tokenizer's chat_template +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +3. Using the tokenizer_config.json's chat template or `chatml` as fallback if the former's chat template does not exist, on OpenAI messages format, training on all assistant messages. + +```yaml +chat_template: tokenizer_default_fallback_chatml # this overwrites the tokenizer's chat_template +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +4. Using a custom jinja template on OpenAI messages format, training on all assistant messages. + +```yaml +# chat_template: jinja # `jinja` will be implied if the `chat_template_jinja` is set and this field is empty +chat_template_jinja: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}" + +datasets: + - path: ... + type: chat_template + roles_to_train: ["assistant"] +``` + +5. 
(Advanced) Using fine-grained control over tokens and turns to train in a conversation + +For a data sample that looks like: + +```{.json filename="data.jsonl"} +{ + "conversations": [ + {"from": "system", "value": "You are an AI assistant.", "train": false}, + {"from": "human", "value": "Hello", "train": false}, + {"from": "assistant", "value": "Hello", "train": true}, + {"from": "human", "value": "How are you?", "train": true}, + { + "from": "assistant", + "value": "I'm doing very well, thank you!", + "train_detail": [ + {"begin_offset": 0, "end_offset": 8, "train": false}, + {"begin_offset": 9, "end_offset": 18, "train": true}, + {"begin_offset": 19, "end_offset": 30, "train": false}, + ], + }, + { + "from": "human", + "value": "I'm doing very well, thank you!", + "train": true, + }, + {"from": "assistant", "value": "Hi there!", "train": true} + ] +} +``` + +The configuration would look like: + +```yaml +datasets: + - path: ... + type: chat_template + chat_template: tokenizer_default + field_messages: conversations + message_field_role: from + message_field_content: value + roles_to_train: [] + train_on_eos: turn + message_field_training: train + message_field_training_detail: train_detail +``` + +Tip: It is not necessary to use both `message_field_training` and `message_field_training_detail` at a time. diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 77bb551f8..52765a9b5 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -30,7 +30,7 @@ from axolotl.integrations.base import PluginManager from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.comet_ import setup_comet_env_vars from axolotl.utils.config import ( normalize_cfg_datasets, @@ -272,7 +272,7 @@ def do_inference_gradio( importlib.import_module("axolotl.prompters"), prompter ) elif cfg.chat_template: - chat_template_str = chat_templates(cfg.chat_template) + chat_template_str = get_chat_template(cfg.chat_template) model = model.to(cfg.device, dtype=cfg.torch_dtype) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 319ea7be5..d125f838d 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -63,7 +63,7 @@ log_prediction_callback_factory, ) from axolotl.utils.callbacks.lisa import lisa_callback_factory -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, @@ -1594,7 +1594,7 @@ def build(self, total_num_steps): training_arguments_kwargs["model_type"] = self.cfg.model_config_type training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset) if self.cfg.chat_template: - training_arguments_kwargs["chat_template"] = chat_templates( + training_arguments_kwargs["chat_template"] = get_chat_template( self.cfg.chat_template ) diff --git a/src/axolotl/prompt_strategies/bradley_terry/__init__.py b/src/axolotl/prompt_strategies/bradley_terry/__init__.py index 849d84e45..4457c50be 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/__init__.py +++ b/src/axolotl/prompt_strategies/bradley_terry/__init__.py @@ -6,7 +6,7 @@ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig -LOG = logging.getLogger("axolotl.prompt_strategies") +LOG = 
logging.getLogger("axolotl.prompt_strategies.bradley_terry") def load(strategy, tokenizer, cfg, ds_cfg): diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index ccda0a4bd..fa85cdcb2 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -2,13 +2,18 @@ Bradley-Terry model with chat template prompt strategy. """ +import logging from typing import Any, Dict, Optional from axolotl.prompt_strategies.chat_template import ( ChatTemplatePrompter, ChatTemplateStrategy, ) -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config + +# Configure the logger +LOG = logging.getLogger("axolotl.prompt_strategies.bradley_terry.chat_template") +LOG.setLevel(logging.INFO) class BTChatTemplateStrategy(ChatTemplateStrategy): @@ -27,18 +32,24 @@ def tokenize_prompt(self, prompt): # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["chosen"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) chosen_tokenized = super().tokenize_prompt(prompt) self.messages = "rejected_messages" # pylint: disable=duplicate-code prompt[self.messages] = [] if prompt["system"]: - prompt[self.messages].append({"from": "system", "value": prompt["system"]}) - prompt[self.messages].append({"from": "user", "value": prompt["input"]}) - prompt[self.messages].append({"from": "assistant", "value": prompt["rejected"]}) + prompt[self.messages].append( + {"role": "system", "content": prompt["system"]} + ) + prompt[self.messages].append({"role": "user", "content": prompt["input"]}) + prompt[self.messages].append( + {"role": "assistant", "content": prompt["rejected"]} + ) rejected_tokenized = super().tokenize_prompt(prompt) return { @@ -53,15 +64,18 @@ def tokenize_prompt(self, prompt): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) prompter_params = { "tokenizer": tokenizer, - "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), - "message_field_role": ds_cfg.get("message_field_role", "from"), - "message_field_content": ds_cfg.get("message_field_content", "value"), - "message_field_training": ds_cfg.get("message_field_training", "training"), + "chat_template": chat_template_string, + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), + "message_field_training": ds_cfg.get("message_field_training", None), "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", "train_detail" + "message_field_training_detail", None ), "roles": ds_cfg.get("roles"), "drop_system_message": ds_cfg.get("drop_system_message", False), @@ -74,8 +88,8 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): strategy_params = { "train_on_inputs": cfg.train_on_inputs, "sequence_len": cfg.sequence_len, - 
"roles_to_train": ds_cfg.get("roles_to_train", ["gpt", "assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + "roles_to_train": ds_cfg.get("roles_to_train", []), + "train_on_eos": ds_cfg.get("train_on_eos", None), } strategy = BTChatTemplateStrategy( diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index c7852a707..0946a4b8c 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -9,7 +9,7 @@ from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config # Configure the logger LOG = logging.getLogger("axolotl") @@ -405,10 +405,14 @@ def get_images(self, prompt): def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): # pylint: disable=duplicate-code ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") prompter_params = { "tokenizer": tokenizer, - "chat_template": chat_templates(ds_cfg.get("chat_template", "chatml")), + "chat_template": chat_template_string, "message_field_role": ds_cfg.get("message_field_role", "role"), "message_field_content": ds_cfg.get("message_field_content", "content"), "message_field_training": ds_cfg.get("message_field_training", None), diff --git a/src/axolotl/prompt_strategies/dpo/chat_template.py b/src/axolotl/prompt_strategies/dpo/chat_template.py index e0e5eb129..489b86485 100644 --- a/src/axolotl/prompt_strategies/dpo/chat_template.py +++ b/src/axolotl/prompt_strategies/dpo/chat_template.py @@ -2,15 +2,16 @@ DPO prompt strategies for using tokenizer chat templates. 
""" -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template def default( cfg, dataset_idx=0, **kwargs ): # pylint: disable=possibly-unused-variable,unused-argument ds_cfg = cfg["datasets"][dataset_idx] - chat_template_str = chat_templates(cfg.chat_template) - + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) field_messages = ds_cfg.get("field_messages", "messages") field_chosen = ds_cfg.get("field_chosen", "chosen") field_rejected = ds_cfg.get("field_rejected", "rejected") @@ -30,6 +31,12 @@ def default( role_map[source] = target def transform_fn(sample, tokenizer=None): + chat_template_string = get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + messages = sample[field_messages] messages = [ { @@ -46,28 +53,29 @@ def transform_fn(sample, tokenizer=None): "role": role_map[sample[field_rejected][field_message_role]], "content": sample[field_rejected][field_message_content], } + dummy_user_message = {"role": "user", "content": "[[dummy_message]]"} result = {} result["prompt"] = tokenizer.apply_chat_template( messages, add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) result["chosen"] = tokenizer.apply_chat_template( - [chosen], + [dummy_user_message, chosen], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) chosen_strip_index = result["chosen"].find(chosen["content"]) result["chosen"] = result["chosen"][chosen_strip_index:].rstrip() result["rejected"] = tokenizer.apply_chat_template( - [rejected], + [dummy_user_message, rejected], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) rejected_strip_index = result["rejected"].find(rejected["content"]) diff --git a/src/axolotl/prompt_strategies/orpo/chat_template.py b/src/axolotl/prompt_strategies/orpo/chat_template.py index bba694856..e53a54748 100644 --- a/src/axolotl/prompt_strategies/orpo/chat_template.py +++ b/src/axolotl/prompt_strategies/orpo/chat_template.py @@ -5,7 +5,7 @@ from axolotl.prompt_tokenizers import IGNORE_INDEX, PromptTokenizingStrategy from axolotl.prompters import Prompter -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config class Message(BaseModel): @@ -28,18 +28,13 @@ def load( """ chatml transforms for datasets with system, input, chosen, rejected """ - - chat_template = chat_templates("chatml") - if ds_cfg and "chat_template" in ds_cfg: - chat_template = ds_cfg["chat_template"] - try: - chat_template = chat_templates(chat_template) - except ValueError: - pass - tokenizer.chat_template = chat_template + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + tokenizer.chat_template = chat_template_string return ORPOTokenizingStrategy( - ORPOPrompter(chat_template, tokenizer), + ORPOPrompter(chat_template_string, tokenizer), tokenizer, cfg.train_on_inputs, cfg.sequence_len, @@ -248,28 +243,30 @@ def build_prompt( def argilla(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument dataset_parser = ORPODatasetParsingStrategy() - chat_template_str = chat_templates(cfg.chat_template) - def transform_fn(sample, tokenizer=None): res = {} + chat_template_string = 
get_chat_template_from_config( + cfg=cfg, tokenizer=tokenizer + ) + res["prompt"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages], add_generation_prompt=True, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, ) prompt_str_len = len(res["prompt"]) res["chosen"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] res["rejected"] = tokenizer.apply_chat_template( [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages], add_generation_prompt=False, - chat_template=chat_template_str, + chat_template=chat_template_string, tokenize=False, )[prompt_str_len:] diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py index 4565c35d5..069d243f5 100644 --- a/src/axolotl/prompt_strategies/sharegpt.py +++ b/src/axolotl/prompt_strategies/sharegpt.py @@ -62,7 +62,7 @@ def build_loader( ): def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): LOG.warning( - "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead.", + "sharegpt type support will be deprecated in the next release of Axolotl. Please use chat_template instead. https://axolotl-ai-cloud.github.io/axolotl/docs/dataset-formats/conversation.html#chat_template", ) conversation = ( ds_cfg["conversation"] diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 2443f56f9..dfb3fef21 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -2,8 +2,19 @@ This module provides functionality for selecting chat templates based on user choices. These templates are used for formatting messages in a conversation. """ +import logging +from typing import TYPE_CHECKING, Any, Dict, Optional -CHAT_TEMPLATES = { +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase + +LOG = logging.getLogger("axolotl.utils.chat_templates") + +_JINJA_TEMPALTE_CHOICE = "jinja" +_DEFAULT_TEMPLATE_CHOICE = "tokenizer_default" +_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX = "tokenizer_default_fallback_" + +_CHAT_TEMPLATES = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... 
"mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... @@ -21,12 +32,18 @@ } -def chat_templates(user_choice: str): +def get_chat_template( + user_choice: str, + jinja_template: Optional[str] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +): """ - Finds the correct chat_template for the tokenizer_config. + Finds the correct chat_template based on the user's choice, jinja_template, and tokenizer. Args: user_choice (str): The user's choice of template. + jinja_template (Optional[str], optional): The jinja template string. Defaults to None. + tokenizer (Optional[PreTrainedTokenizerBase], optional): The tokenizer. Defaults to None. Returns: str: The chosen template string. @@ -34,13 +51,71 @@ def chat_templates(user_choice: str): Raises: ValueError: If the user_choice is not found in the templates. """ + if user_choice == _JINJA_TEMPALTE_CHOICE: + if not jinja_template: + raise ValueError( + f"`jinja_template` cannot be None when `chat_template` choice is {_JINJA_TEMPALTE_CHOICE}" + ) + return jinja_template + + if user_choice == _DEFAULT_TEMPLATE_CHOICE: + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice is {_DEFAULT_TEMPLATE_CHOICE}" + ) + if not tokenizer.chat_template: + raise ValueError( + f"`chat_template choice is {_DEFAULT_TEMPLATE_CHOICE} but tokenizer's chat_template is null. " + f"Please add a chat_template in tokenizer config" + ) + return tokenizer.chat_template + + if user_choice.startswith(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX): + if not tokenizer: + raise ValueError( + f"`tokenizer` cannot be None when chat_template choice starts with {_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX}" + ) + if tokenizer.chat_template: + return tokenizer.chat_template - if user_choice in CHAT_TEMPLATES: - return CHAT_TEMPLATES[user_choice] + user_choice = user_choice[ + len(_DEFAULT_FALLBACK_CHATML_TEMPLATE_CHOICE_PREFIX) : + ] + LOG.warning( + f"No chat template found on tokenizer, falling back to {user_choice}. It is recommended to set --train_on_inputs to True for the model to learn this chat template." 
+ ) + + if user_choice in _CHAT_TEMPLATES: + return _CHAT_TEMPLATES[user_choice] raise ValueError(f"Template '{user_choice}' not found.") +def extract_chat_template_args(cfg, ds_cfg: Optional[Dict[str, Any]] = None): + if ds_cfg and ds_cfg.get("chat_template"): + chat_template_choice = ds_cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = ds_cfg.get("chat_template_jinja") + else: + chat_template_choice = cfg.get("chat_template") or _DEFAULT_TEMPLATE_CHOICE + chat_template_jinja = cfg.get("chat_template_jinja") + return chat_template_choice, chat_template_jinja + + +def get_chat_template_from_config( + cfg, + ds_cfg: Optional[Dict[str, Any]] = None, + tokenizer: Optional["PreTrainedTokenizerBase"] = None, +) -> str: + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg=cfg, ds_cfg=ds_cfg + ) + return get_chat_template( + user_choice=chat_template_choice, + jinja_template=chat_template_jinja, + tokenizer=tokenizer, + ) + + def register_chat_template(template_name: str, chat_template: str): """ Registers chat templates. @@ -50,7 +125,7 @@ def register_chat_template(template_name: str, chat_template: str): chat_template (str): The template string. """ - if template_name in CHAT_TEMPLATES: + if template_name in _CHAT_TEMPLATES: raise ValueError(f"Template '{template_name}' already exists.") - CHAT_TEMPLATES[template_name] = chat_template + _CHAT_TEMPLATES[template_name] = chat_template diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index f732db06f..afc8c4fc4 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -228,6 +228,7 @@ def normalize_cfg_datasets(cfg): f"updating dataset {ds_cfg.path} with `chat_template: {cfg.chat_template}` to match your chat_template" ) cfg.datasets[idx].chat_template = cfg.chat_template + cfg.datasets[idx].chat_template_jinja = cfg.chat_template_jinja def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 16cf312ce..96e533000 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -8,9 +8,16 @@ import os from enum import Enum from importlib.metadata import version -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union -from pydantic import BaseModel, Field, conlist, field_validator, model_validator +from pydantic import ( + BaseModel, + Field, + StringConstraints, + conlist, + field_validator, + model_validator, +) from transformers import SchedulerType from transformers.training_args import OptimizerNames @@ -21,6 +28,37 @@ SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"} +class RLType(str, Enum): + """RL trainer type configuration subset""" + + dpo = "dpo" # pylint: disable=invalid-name + ipo = "ipo" # pylint: disable=invalid-name + orpo = "orpo" # pylint: disable=invalid-name + kto = "kto" # pylint: disable=invalid-name + simpo = "simpo" # pylint: disable=invalid-name + + +class ChatTemplate(str, Enum): + """Chat templates configuration subset""" + + alpaca = "alpaca" # pylint: disable=invalid-name + chatml = "chatml" # pylint: disable=invalid-name + mistral_v1 = "mistral_v1" # pylint: disable=invalid-name + mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name + 
mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name + gemma = "gemma" # pylint: disable=invalid-name + cohere = "cohere" # pylint: disable=invalid-name + llama3 = "llama3" # pylint: disable=invalid-name + llama3_2_vision = "llama3_2_vision" # pylint: disable=invalid-name + phi_3 = "phi_3" # pylint: disable=invalid-name + phi_35 = "phi_35" # pylint: disable=invalid-name + deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name + jamba = "jamba" # pylint: disable=invalid-name + jinja = "jinja" # pylint: disable=invalid-name + qwen_25 = "qwen_25" # pylint: disable=invalid-name + tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name + + class DeprecatedParameters(BaseModel): """configurations that are deprecated""" @@ -105,13 +143,19 @@ class SFTDataset(BaseModel): input_transform: Optional[str] = None shards: Optional[int] = None conversation: Optional[str] = None - chat_template: Optional[str] = None + # Do not make this too strict or it will break the validator to choose different dataset class + chat_template: Optional[ + Union[ + ChatTemplate, + str, + ] + ] = None + chat_template_jinja: Optional[str] = None data_files: Optional[Union[str, List[str]]] = None input_format: Optional[str] = None name: Optional[str] = None ds_type: Optional[str] = None train_on_split: Optional[str] = None - field: Optional[str] = None field_human: Optional[str] = None field_model: Optional[str] = None @@ -122,13 +166,32 @@ class SFTDataset(BaseModel): message_field_training_detail: Optional[str] = None roles_to_train: Optional[List[str]] = None train_on_eos: Optional[str] = None - roles: Optional[Dict[str, List[str]]] = None drop_system_message: Optional[bool] = None - trust_remote_code: Optional[bool] = False revision: Optional[str] = None + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # Set chat_template to tokenizer_default if not set + if data.get("type") == "chat_template" and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.tokenizer_default + + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + class UserDefinedDPOType(BaseModel): """User defined typing for DPO""" @@ -174,35 +237,6 @@ class KTODataset(BaseModel): revision: Optional[str] = None -class RLType(str, Enum): - """RL trainer type configuration subset""" - - dpo = "dpo" # pylint: disable=invalid-name - ipo = "ipo" # pylint: disable=invalid-name - orpo = "orpo" # pylint: disable=invalid-name - kto = "kto" # pylint: disable=invalid-name - simpo = "simpo" # pylint: disable=invalid-name - - -class ChatTemplate(str, Enum): - """Chat templates configuration subset""" - - alpaca = "alpaca" # pylint: disable=invalid-name - chatml = "chatml" # pylint: disable=invalid-name - mistral_v1 = "mistral_v1" # pylint: disable=invalid-name - mistral_v2v3 = "mistral_v2v3" # pylint: disable=invalid-name - mistral_v3_tekken = "mistral_v3_tekken" # pylint: disable=invalid-name - gemma = "gemma" # pylint: disable=invalid-name - cohere = "cohere" # pylint: disable=invalid-name - llama3 = "llama3" # pylint: disable=invalid-name - llama3_2_vision = 
"llama3_2_vision" # pylint: disable=invalid-name - phi_3 = "phi_3" # pylint: disable=invalid-name - phi_35 = "phi_35" # pylint: disable=invalid-name - deepseek_v2 = "deepseek_v2" # pylint: disable=invalid-name - jamba = "jamba" # pylint: disable=invalid-name - qwen_25 = "qwen_25" # pylint: disable=invalid-name - - class LoftQConfig(BaseModel): """LoftQ configuration subset""" @@ -719,7 +753,13 @@ class Config: gpu_memory_limit: Optional[Union[int, str]] = None low_cpu_mem_usage: Optional[bool] = None - chat_template: Optional[ChatTemplate] = None + chat_template: Optional[ + Union[ + ChatTemplate, + Annotated[str, StringConstraints(pattern="^tokenizer_default_fallback_")], + ] + ] = None + chat_template_jinja: Optional[str] = None default_system_message: Optional[str] = None fix_untrained_tokens: Optional[bool] = None @@ -828,6 +868,23 @@ def check_sample_packing_w_xformers(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_chat_template_config(cls, data): + # if chat_template is set to jinja, chat_template_jinja is required + if data.get("chat_template") == ChatTemplate.jinja and not data.get( + "chat_template_jinja" + ): + raise ValueError( + "chat_template_jinja is required when chat_template is set to jinja" + ) + + # If chat_template_jinja is set, set chat_template to jinja + if data.get("chat_template_jinja") and not data.get("chat_template"): + data["chat_template"] = ChatTemplate.jinja + + return data + @model_validator(mode="before") @classmethod def check_sample_packing_wo_flash(cls, data): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 97844a5bf..f3386cccf 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -53,7 +53,7 @@ ) from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN from axolotl.utils.bench import log_gpu_memory_usage -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.dict import DictDefault from axolotl.utils.distributed import zero_only from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper @@ -296,7 +296,10 @@ def load_tokenizer(cfg): LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") if cfg.chat_template: - chat_template_string = chat_templates(cfg.chat_template) + chat_template_string = get_chat_template_from_config( + cfg=cfg, + tokenizer=tokenizer, + ) if cfg.default_system_message and cfg.chat_template == "chatml": chat_template_string = chat_template_string.replace( "You are a helpful assistant.", cfg.default_system_message diff --git a/tests/prompt_strategies/test_chat_template_utils.py b/tests/prompt_strategies/test_chat_template_utils.py new file mode 100644 index 000000000..b63c9aa17 --- /dev/null +++ b/tests/prompt_strategies/test_chat_template_utils.py @@ -0,0 +1,125 @@ +""" +Tests for utils in axolotl.utils.chat_templates +""" +import unittest + +import pytest +from transformers import AutoTokenizer + +from axolotl.utils.chat_templates import ( + _CHAT_TEMPLATES, + extract_chat_template_args, + get_chat_template, +) + + +@pytest.fixture(name="llama3_tokenizer") +def fixture_llama3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B") + + return tokenizer + + +class TestGetChatTemplateUtils: + """ + Tests the get_chat_template function. 
+ """ + + def test_known_chat_template(self): + chat_template_str = get_chat_template("llama3") + assert chat_template_str == _CHAT_TEMPLATES["llama3"] + + def test_invalid_chat_template(self): + with pytest.raises(ValueError) as exc: + get_chat_template("invalid_template") + assert str(exc) == "Template 'invalid_template' not found." + + def test_tokenizer_default_no_tokenizer(self): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default", tokenizer=None) + + def test_tokenizer_default_no_chat_template_on_tokenizer(self, llama3_tokenizer): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default", tokenizer=llama3_tokenizer) + + def test_tokenizer_default_with_chat_template_on_tokenizer(self, llama3_tokenizer): + llama3_tokenizer.chat_template = "test_template" + chat_template_str = get_chat_template( + "tokenizer_default", tokenizer=llama3_tokenizer + ) + assert chat_template_str == "test_template" + + def test_tokenizer_default_fallback_no_tokenizer(self): + with pytest.raises(ValueError): + get_chat_template("tokenizer_default_fallback_test", tokenizer=None) + + def test_tokenizer_default_fallback_no_chat_template_on_tokenizer( + self, llama3_tokenizer + ): + chat_template_str = get_chat_template( + "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer + ) + assert chat_template_str == get_chat_template("chatml") + + def test_tokenizer_default_fallback_with_chat_template_on_tokenizer( + self, llama3_tokenizer + ): + llama3_tokenizer.chat_template = "test_template" + chat_template_str = get_chat_template( + "tokenizer_default_fallback_chatml", tokenizer=llama3_tokenizer + ) + assert chat_template_str == "test_template" + + def test_jinja_template_mode(self): + jinja_template = "example_jinja_template" + chat_template_str = get_chat_template("jinja", jinja_template=jinja_template) + assert chat_template_str == jinja_template + + def test_jinja_template_mode_no_jinja_template(self): + with pytest.raises(ValueError): + get_chat_template("jinja", jinja_template=None) + + def test_extract_chat_template_args(self): + # No ds_cfg + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={"chat_template": "chatml"}, + ) + assert chat_template_choice == "chatml" + assert chat_template_jinja is None + + # ds_cfg provided + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={ + "chat_template": "jinja", + "chat_template_jinja": "global_jinja_template", + }, + ds_cfg={"chat_template": "llama3", "chat_template_jinja": None}, + ) + assert chat_template_choice == "llama3" + assert chat_template_jinja is None + + # ds_cfg provided with jinja template + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={"chat_template": "chatml", "chat_template_jinja": None}, + ds_cfg={ + "chat_template": "jinja", + "chat_template_jinja": "ds_jinja_template", + }, + ) + assert chat_template_choice == "jinja" + assert chat_template_jinja == "ds_jinja_template" + + # ds_cfg provided with no chat_template + chat_template_choice, chat_template_jinja = extract_chat_template_args( + cfg={ + "chat_template": "jinja", + "chat_template_jinja": "global_jinja_template", + }, + ds_cfg={"chat_template": None, "chat_template_jinja": "ds_jinja_template"}, + ) + assert chat_template_choice == "jinja" + assert chat_template_jinja == "global_jinja_template" + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/prompt_strategies/test_chat_templates.py b/tests/prompt_strategies/test_chat_templates.py index 
20533504c..4ec12b82c 100644 --- a/tests/prompt_strategies/test_chat_templates.py +++ b/tests/prompt_strategies/test_chat_templates.py @@ -11,7 +11,7 @@ load, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template from axolotl.utils.dict import DictDefault logging.basicConfig(level=logging.DEBUG) @@ -73,7 +73,7 @@ def test_llama3(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_role="role", message_field_content="content", roles={ @@ -113,7 +113,7 @@ def test_phi35(self, phi35_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( phi35_tokenizer, - chat_template=chat_templates("phi_35"), + chat_template=get_chat_template("phi_35"), message_field_role="role", message_field_content="content", roles={ @@ -171,7 +171,7 @@ def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), message_field_role="role", message_field_content="content", message_field_training="training", @@ -230,7 +230,7 @@ def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -283,7 +283,7 @@ def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -336,7 +336,7 @@ def test_llama3_system_human(self, llama3_tokenizer, basic_dataset): # pylint: disable=duplicate-code strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, diff --git a/tests/prompt_strategies/test_chat_templates_advanced.py b/tests/prompt_strategies/test_chat_templates_advanced.py index 50429e3a2..be8e3ccdf 100644 --- a/tests/prompt_strategies/test_chat_templates_advanced.py +++ b/tests/prompt_strategies/test_chat_templates_advanced.py @@ -12,7 +12,7 @@ ChatTemplateStrategy, ) from axolotl.prompters import IGNORE_TOKEN_ID -from axolotl.utils.chat_templates import chat_templates +from axolotl.utils.chat_templates import get_chat_template logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger("axolotl") @@ -35,7 +35,7 @@ def test_train_on_inputs_true(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=True") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -80,7 +80,7 @@ def test_train_on_inputs_false(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_inputs=False") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - 
llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -123,7 +123,7 @@ def test_roles_to_train_assistant_only(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with assistant only") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -151,7 +151,7 @@ def test_roles_to_train_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing roles_to_train with all roles") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=True, @@ -184,7 +184,7 @@ def test_empty_roles_to_train(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with empty roles_to_train") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -205,7 +205,7 @@ def test_train_on_eos_all(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='all'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -232,7 +232,7 @@ def test_train_on_eos_turn(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='turn'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -282,7 +282,7 @@ def test_train_on_eos_last(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='last'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -315,7 +315,7 @@ def test_train_on_eos_none(self, llama3_tokenizer, basic_dataset): LOG.info("Testing with train_on_eos='none'") strategy = ChatTemplateStrategy( ChatTemplatePrompter( - llama3_tokenizer, chat_template=chat_templates("llama3") + llama3_tokenizer, chat_template=get_chat_template("llama3") ), tokenizer=llama3_tokenizer, train_on_inputs=False, @@ -343,7 +343,7 @@ def test_drop_system_message(self, llama3_tokenizer, basic_dataset): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), drop_system_message=True, ), tokenizer=llama3_tokenizer, @@ -371,7 +371,7 @@ def test_custom_roles(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + chat_template=get_chat_template("llama3"), roles=custom_roles, ), tokenizer=llama3_tokenizer, @@ -424,7 +424,7 @@ def test_message_field_training(self, llama3_tokenizer): strategy = ChatTemplateStrategy( ChatTemplatePrompter( llama3_tokenizer, - chat_template=chat_templates("llama3"), + 
chat_template=get_chat_template("llama3"), message_field_training="train", message_field_training_detail="train_detail", ), diff --git a/tests/prompt_strategies/test_dpo_chat_templates.py b/tests/prompt_strategies/test_dpo_chat_templates.py index cca48b1cf..740edc22f 100644 --- a/tests/prompt_strategies/test_dpo_chat_templates.py +++ b/tests/prompt_strategies/test_dpo_chat_templates.py @@ -86,6 +86,20 @@ def fixture_llama3_tokenizer(): return tokenizer +@pytest.fixture(name="phi3_tokenizer") +def fixture_phi3_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-128k-instruct") + + return tokenizer + + +@pytest.fixture(name="gemma_tokenizer") +def fixture_gemma_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-2b-it", revision="703fb4a") + + return tokenizer + + class TestAssistantDPOChatTemplateLlama3: """ Test class for assistant style datasets with llama-3 prompts using the chat_template strategy. @@ -99,7 +113,7 @@ def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset): "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", } ], } @@ -124,7 +138,7 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): "chat_template": "llama3", "datasets": [ { - "chat_template": "llama3", + "type": "chat_template", "field_messages": "conversation", "field_chosen": "better", "field_rejected": "worse", @@ -152,5 +166,65 @@ def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset): assert result["rejected"] == "party on<|eot_id|>" +class TestAssistantDPOChatTemplatePhi3: + """ + Test class for assistant style datasets with phi-3 prompts using the tokenizer's chat_template strategy. + """ + + def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=phi3_tokenizer) + assert result["prompt"] == ( + "<|user|>\nhello<|end|>\n" + + "<|assistant|>\nhello<|end|>\n" + + "<|user|>\ngoodbye<|end|>\n" + + "<|assistant|>\n" + ) + assert result["chosen"] == "goodbye<|end|>" + assert result["rejected"] == "party on<|end|>" + + +class TestAssistantDPOChatTemplateGemma: + """ + Test class for assistant style datasets with gemma prompts using the tokenizer's chat_template strategy. 
+ """ + + def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset): + # pylint: disable=duplicate-code + transform_fn = default( + DictDefault( + { + "chat_template": "tokenizer_default", + "datasets": [ + { + "type": "chat_template", + } + ], + } + ) + ) + result = transform_fn(assistant_dataset[0], tokenizer=gemma_tokenizer) + assert result["prompt"] == ( + "user\nhello\n" + + "model\nhello\n" + + "user\ngoodbye\n" + + "model\n" + ) + assert result["chosen"] == "goodbye" + assert result["rejected"] == "party on" + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_validation_dataset.py b/tests/test_validation_dataset.py new file mode 100644 index 000000000..389424217 --- /dev/null +++ b/tests/test_validation_dataset.py @@ -0,0 +1,238 @@ +"""Module for testing the validation module for the dataset config""" + +import warnings +from typing import Optional + +import pytest + +from axolotl.utils.config import validate_config +from axolotl.utils.config.models.input.v0_4_1 import ChatTemplate +from axolotl.utils.dict import DictDefault + +warnings.filterwarnings("error") + + +@pytest.fixture(name="minimal_cfg") +def fixture_cfg(): + return DictDefault( + { + "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6", + "learning_rate": 0.000001, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + } + ) + + +# pylint: disable=too-many-public-methods (duplicate-code) +class BaseValidation: + """ + Base validation module to setup the log capture + """ + + _caplog: Optional[pytest.LogCaptureFixture] = None + + @pytest.fixture(autouse=True) + def inject_fixtures(self, caplog): + self._caplog = caplog + + +class TestValidationCheckDatasetConfig(BaseValidation): + """ + Test the validation for the dataset config to ensure no correct parameters are dropped + """ + + def test_dataset_config_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "sharegpt", + "conversation": "chatml", + "shards": 10, + } + ] + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.datasets[0].conversation == cfg.datasets[0].conversation + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_default_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template is None + assert ( + checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + 
) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_partial_default_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "chat_template": "chatml", + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template == ChatTemplate.chatml + assert ( + checked_cfg.datasets[0].chat_template == ChatTemplate.tokenizer_default + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config() + + def test_dataset_chatml_chat_template_no_drop_param(self, minimal_cfg): + cfg = DictDefault( + minimal_cfg + | { + "chat_template": "chatml", + "datasets": [ + { + "path": "LDJnr/Puffin", + "type": "chat_template", + "chat_template": "gemma", + "field_messages": "conversations", + "shards": 10, + "message_field_role": "from", + "message_field_content": "value", + } + ], + } + ) + + checked_cfg = validate_config(cfg) + + def _check_config(): + assert checked_cfg.datasets[0].path == cfg.datasets[0].path + assert checked_cfg.datasets[0].type == cfg.datasets[0].type + assert checked_cfg.chat_template == cfg.chat_template + assert ( + checked_cfg.datasets[0].chat_template == cfg.datasets[0].chat_template + ) + assert ( + checked_cfg.datasets[0].field_messages == cfg.datasets[0].field_messages + ) + assert checked_cfg.datasets[0].shards == cfg.datasets[0].shards + assert ( + checked_cfg.datasets[0].message_field_role + == cfg.datasets[0].message_field_role + ) + assert ( + checked_cfg.datasets[0].message_field_content + == cfg.datasets[0].message_field_content + ) + + _check_config() + + checked_cfg = validate_config( + cfg, + capabilities={ + "bf16": "false", + "n_gpu": 1, + "compute_capability": "8.0", + }, + ) + + _check_config()
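
As a quick illustration of how the new helpers in `src/axolotl/utils/chat_templates.py` compose — a minimal sketch, not part of the patch itself, using the same base Llama-3 tokenizer the tests above download and assuming (as those tests do) that it ships without a `chat_template`, plus network access to fetch it:

```python
# Illustrative sketch only: exercises the helpers added by this patch.
from transformers import AutoTokenizer

from axolotl.utils.chat_templates import (
    extract_chat_template_args,
    get_chat_template,
    get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault

# Global config with a fallback choice; the dataset entry leaves chat_template
# unset, so extract_chat_template_args falls back to the global value.
cfg = DictDefault(
    {
        "chat_template": "tokenizer_default_fallback_chatml",
        "datasets": [{"path": "...", "type": "chat_template"}],
    }
)

# Same base tokenizer as the fixtures above; assumed to have no chat_template.
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")

choice, jinja = extract_chat_template_args(cfg=cfg, ds_cfg=cfg["datasets"][0])
# choice == "tokenizer_default_fallback_chatml", jinja is None

# No template on the tokenizer, so the helper logs a warning and returns the
# built-in chatml template instead.
template = get_chat_template(
    user_choice=choice, jinja_template=jinja, tokenizer=tokenizer
)

# One-call form used by the prompt strategies and load_tokenizer.
assert template == get_chat_template_from_config(
    cfg=cfg, ds_cfg=cfg["datasets"][0], tokenizer=tokenizer
)

# A custom template is passed through verbatim when the choice is "jinja".
assert get_chat_template("jinja", jinja_template="{{ messages }}") == "{{ messages }}"
```

If the tokenizer did define a `chat_template`, both `tokenizer_default` and the `tokenizer_default_fallback_*` choices would return it unchanged; only the built-in names (chatml, llama3, gemma, ...) override it.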