
Commit 46ac3be

Templates: Support list style chat_template keys
HuggingFace updated transformers so that tokenizers can provide their chat templates as a list. Update the template lookup to support this new format. When a template name is given as the "prompt_template" value in config.yml, the lookup now also searches inside that template list. In addition, log template exceptions but continue model loading, since a template failure shouldn't shut down the application.

Signed-off-by: kingbri <bdashore3@proton.me>
1 parent 5bb4995 commit 46ac3be
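
For context, the list-style format this commit supports stores each template as an object with "name" and "template" keys under the chat_template key of tokenizer_config.json (this reading is grounded in the diff's accessors below; the template names and Jinja contents in this sketch are made up for illustration):

{
  "chat_template": [
    {
      "name": "default",
      "template": "{% for message in messages %}...{% endfor %}"
    },
    {
      "name": "rag",
      "template": "{% for document in documents %}...{% endfor %}"
    }
  ]
}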

3 files changed: 78 additions, 20 deletions

backends/exllamav2/model.py

Lines changed: 23 additions & 8 deletions

@@ -4,6 +4,7 @@
 import pathlib
 import threading
 import time
+import traceback

 import torch
 from exllamav2 import (
@@ -30,6 +31,7 @@
 )
 from common.templating import (
     PromptTemplate,
+    TemplateLoadError,
     find_template_from_model,
     get_template_from_model_json,
     get_template_from_file,
@@ -194,7 +196,7 @@ def progress(loaded_modules: int, total_modules: int,
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
-                f"Using template {self.prompt_template.name} " "for chat completions."
+                f'Using template "{self.prompt_template.name}" for chat completions.'
             )
         else:
             logger.warning(
@@ -259,23 +261,36 @@ def find_prompt_template(self, prompt_template_name, model_directory):
             lambda: get_template_from_model_json(
                 pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                 "chat_template",
-                "from_tokenizer_config",
             ),
             lambda: get_template_from_file(find_template_from_model(model_directory)),
         ]

         # Add lookup from prompt template name if provided
         if prompt_template_name:
-            find_template_functions.insert(
-                0, lambda: get_template_from_file(prompt_template_name)
-            )
+            find_template_functions[:0] = [
+                lambda: get_template_from_file(prompt_template_name),
+                lambda: get_template_from_model_json(
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                    "chat_template",
+                    prompt_template_name,
+                ),
+            ]

-        for func in find_template_functions:
+        # Continue on exception since functions are tried as they fail
+        for template_func in find_template_functions:
             try:
-                prompt_template = func()
+                prompt_template = template_func()
                 if prompt_template is not None:
                     return prompt_template
-            except (FileNotFoundError, LookupError):
+            except TemplateLoadError as e:
+                logger.warning(f"TemplateLoadError: {str(e)}")
+                continue
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "An unexpected error happened when trying to load the template. "
+                    "Trying other methods."
+                )
                 continue

     def calculate_rope_alpha(self, base_seq_len):
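
The control flow above amounts to a priority-ordered fallback chain: each lookup function is tried in turn, and a failure in one source simply moves on to the next. A minimal standalone sketch of that pattern (the first_successful helper and its loaders argument are illustrative, not part of this codebase):

from typing import Callable, List, Optional

def first_successful(loaders: List[Callable[[], Optional[str]]]) -> Optional[str]:
    """Try each loader in priority order; fall through on failure."""
    for loader in loaders:
        try:
            result = loader()
            if result is not None:
                return result
        except Exception:
            # A failed source isn't fatal; try the next one
            continue
    return None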

common/templating.py

Lines changed: 52 additions & 11 deletions

@@ -4,12 +4,15 @@
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
+from typing import Optional
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
 from pydantic import BaseModel

+from common.utils import unwrap
+

 class PromptTemplate(BaseModel):
     """A template for chat completion prompts."""
@@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
     template: str


+class TemplateLoadError(Exception):
+    """Raised on prompt template load"""
+
+    pass
+
+
 def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
     """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name
     else:
-        raise LookupError("Could not find template from model name.")
+        raise TemplateLoadError("Could not find template from model name.")


 def get_template_from_file(prompt_template_name: str):
@@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
         )
     else:
         # Let the user know if the template file isn't found
-        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
+        raise TemplateLoadError(
+            f'Chat template "{prompt_template_name}" not found in files.'
+        )


 # Get a template from a JSON file
 # Requires a key and template name
-def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
+def get_template_from_model_json(
+    json_path: pathlib.Path, key: str, name: Optional[str] = None
+):
     """Get a template from a JSON file. Requires a key and template name"""
-    if json_path.exists():
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
-            chat_template = model_config.get(key)
-            if chat_template:
-                return PromptTemplate(name=name, template=chat_template)
-    else:
-        raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
+    if not json_path.exists():
+        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
+
+    with open(json_path, "r", encoding="utf8") as config_file:
+        model_config = json.load(config_file)
+        chat_template = model_config.get(key)
+
+        if not chat_template:
+            raise TemplateLoadError(
+                "Could not find a value from chat_template key in the passed JSON. "
+                "Check the tokenizer config?"
+            )
+
+        if isinstance(chat_template, list):
+            # Handles the new list style of chat templates
+            if name:
+                wrapped_template = next(
+                    (x for x in chat_template if x.get("name") == name),
+                    {},
+                )
+            else:
+                wrapped_template = chat_template[0]
+                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")
+
+            selected_template = wrapped_template.get("template")
+
+            if selected_template:
+                return PromptTemplate(name=name, template=selected_template)
+            else:
+                raise TemplateLoadError(
+                    f'Chat template with name "{name}" not found '
+                    "in model templates list."
+                )
+        else:
+            # Can safely assume the chat template is the old style
+            return PromptTemplate(name="from_tokenizer_config", template=chat_template)
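
As a usage illustration of the updated function (the model directory and the template name "rag" here are hypothetical):

import pathlib
from common.templating import get_template_from_model_json

config_path = pathlib.Path("my-model") / "tokenizer_config.json"

# Without a name: a list-style chat_template yields its first entry, and that
# entry's own "name" (or "from_tokenizer_config") becomes the template name.
default_template = get_template_from_model_json(config_path, "chat_template")

# With a name: the matching list entry is selected; TemplateLoadError is
# raised if no entry carries that name.
named_template = get_template_from_model_json(config_path, "chat_template", "rag")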

config_sample.yml

Lines changed: 3 additions & 1 deletion

@@ -107,7 +107,9 @@ model:
   # Possible values FP16, FP8, Q4. (default: FP16)
   #cache_mode: FP16

-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: Empty)
+  # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
+  # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
+  # of the template you want to use.
   # NOTE: Only works with chat completion message lists!
   #prompt_template:
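
For example, to pick a named template out of a multi-template tokenizer_config.json, the option would be set like this in config.yml (the name "rag" is illustrative):

model:
  prompt_template: rag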
