Templates: Support list style chat_template keys
HuggingFace updated transformers so tokenizers can provide chat
templates as a list of named entries rather than a single string.
Update template handling to support this new format. Providing the
name of a template as the "prompt_template" value in config.yml will
also search inside the template list.

In addition, log if there's a template exception, but continue model
loading, since a failed template lookup shouldn't shut down the
application.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Apr 7, 2024
1 parent 5bb4995 commit 46ac3be
Showing 3 changed files with 78 additions and 20 deletions.
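
For context, transformers now allows tokenizer_config.json to store multiple
chat templates as a list of named entries instead of a single string. A
hypothetical tokenizer_config.json using the list style might look like this
(the "default" and "tool_use" names and the template bodies are illustrative
assumptions, not part of this commit):

{
  "chat_template": [
    {
      "name": "default",
      "template": "{% for message in messages %}{{ message.content }}{% endfor %}"
    },
    {
      "name": "tool_use",
      "template": "{% for message in messages %}...{% endfor %}"
    }
  ]
}
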
31 changes: 23 additions & 8 deletions backends/exllamav2/model.py
@@ -4,6 +4,7 @@
 import pathlib
 import threading
 import time
+import traceback
 
 import torch
 from exllamav2 import (
@@ -30,6 +31,7 @@
 )
 from common.templating import (
     PromptTemplate,
+    TemplateLoadError,
     find_template_from_model,
     get_template_from_model_json,
     get_template_from_file,
@@ -194,7 +196,7 @@ def progress(loaded_modules: int, total_modules: int,
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
-                f"Using template {self.prompt_template.name} " "for chat completions."
+                f'Using template "{self.prompt_template.name}" for chat completions.'
             )
         else:
             logger.warning(
@@ -259,23 +261,36 @@ def find_prompt_template(self, prompt_template_name, model_directory):
             lambda: get_template_from_model_json(
                 pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                 "chat_template",
-                "from_tokenizer_config",
             ),
             lambda: get_template_from_file(find_template_from_model(model_directory)),
         ]
 
         # Add lookup from prompt template name if provided
         if prompt_template_name:
-            find_template_functions.insert(
-                0, lambda: get_template_from_file(prompt_template_name)
-            )
+            find_template_functions[:0] = [
+                lambda: get_template_from_file(prompt_template_name),
+                lambda: get_template_from_model_json(
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                    "chat_template",
+                    prompt_template_name,
+                ),
+            ]
 
-        for func in find_template_functions:
+        # Continue on exception since functions are tried as they fail
+        for template_func in find_template_functions:
             try:
-                prompt_template = func()
+                prompt_template = template_func()
                 if prompt_template is not None:
                     return prompt_template
-            except (FileNotFoundError, LookupError):
+            except TemplateLoadError as e:
+                logger.warning(f"TemplateLoadError: {str(e)}")
                 continue
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "An unexpected error happened when trying to load the template. "
+                    "Trying other methods."
+                )
+                continue
 
     def calculate_rope_alpha(self, base_seq_len):
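
With this change, when "prompt_template" is set in config.yml, the loader
tries each source in order, falling through on TemplateLoadError (a summary
of the code above, not part of the diff): a template file matching the
configured name, then a list entry with that name in the model's
tokenizer_config.json, then the tokenizer config's chat_template key itself
(first list entry, or a plain string template), and finally a bundled
template guessed from the model's directory name.
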
63 changes: 52 additions & 11 deletions common/templating.py
@@ -4,12 +4,15 @@
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
+from typing import Optional
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
 from pydantic import BaseModel
 
+from common.utils import unwrap
+
 
 class PromptTemplate(BaseModel):
     """A template for chat completion prompts."""
@@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
     template: str
 
 
+class TemplateLoadError(Exception):
+    """Raised on prompt template load"""
+
+    pass
+
+
 def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
     """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name
     else:
-        raise LookupError("Could not find template from model name.")
+        raise TemplateLoadError("Could not find template from model name.")
 
 
 def get_template_from_file(prompt_template_name: str):
@@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
             )
     else:
         # Let the user know if the template file isn't found
-        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
+        raise TemplateLoadError(
+            f'Chat template "{prompt_template_name}" not found in files.'
+        )
 
 
 # Get a template from a JSON file
 # Requires a key and template name
-def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
+def get_template_from_model_json(
+    json_path: pathlib.Path, key: str, name: Optional[str] = None
+):
     """Get a template from a JSON file. Requires a key and template name"""
-    if json_path.exists():
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
-            chat_template = model_config.get(key)
-            if chat_template:
-                return PromptTemplate(name=name, template=chat_template)
-    else:
-        raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
+    if not json_path.exists():
+        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
+
+    with open(json_path, "r", encoding="utf8") as config_file:
+        model_config = json.load(config_file)
+        chat_template = model_config.get(key)
+
+        if not chat_template:
+            raise TemplateLoadError(
+                "Could not find a value from chat_template key in the passed JSON. "
+                "Check the tokenizer config?"
+            )
+
+        if isinstance(chat_template, list):
+            # Handles the new list style of chat templates
+            if name:
+                wrapped_template = next(
+                    (x for x in chat_template if x.get("name") == name),
+                    {},
+                )
+            else:
+                wrapped_template = chat_template[0]
+                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")
+
+            selected_template = wrapped_template.get("template")
+
+            if selected_template:
+                return PromptTemplate(name=name, template=selected_template)
+            else:
+                raise TemplateLoadError(
+                    f'Chat template with name "{name}" not found '
+                    "in model templates list."
+                )
+        else:
+            # Can safely assume the chat template is the old style
+            return PromptTemplate(name="from_tokenizer_config", template=chat_template)
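
As a rough usage sketch (assuming a model directory whose
tokenizer_config.json uses the list style shown earlier; the "tool_use" name
and the path are illustrative assumptions, not part of this commit):

import pathlib

from common.templating import TemplateLoadError, get_template_from_model_json

config_path = pathlib.Path("models/MyModel/tokenizer_config.json")

try:
    # With a name, the loader searches the list for a matching "name" entry
    named = get_template_from_model_json(config_path, "chat_template", "tool_use")

    # Without a name, the first list entry (or an old-style string template)
    # is used, falling back to the name "from_tokenizer_config"
    default = get_template_from_model_json(config_path, "chat_template")

    print(named.name, default.name)
except TemplateLoadError as e:
    print(f"Template lookup failed: {e}")
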
4 changes: 3 additions & 1 deletion config_sample.yml
@@ -107,7 +107,9 @@ model:
   # Possible values FP16, FP8, Q4. (default: FP16)
   #cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: Empty)
+  # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
+  # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
+  # of the template you want to use.
   # NOTE: Only works with chat completion message lists!
   #prompt_template:

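
For reference, selecting one of the named templates from a list-style
tokenizer_config.json would then look like this in config.yml (the "tool_use"
name is the illustrative one from the earlier sketch):

model:
  prompt_template: tool_use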