4
4
import pathlib
5
5
from functools import lru_cache
6
6
from importlib .metadata import version as package_version
7
+ from typing import Optional
7
8
from jinja2 import Template , TemplateError
8
9
from jinja2 .sandbox import ImmutableSandboxedEnvironment
9
10
from loguru import logger
10
11
from packaging import version
11
12
from pydantic import BaseModel
12
13
14
+ from common .utils import unwrap
15
+
13
16
14
17
class PromptTemplate (BaseModel ):
15
18
"""A template for chat completion prompts."""
@@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
18
21
template : str
19
22
20
23
24
+ class TemplateLoadError (Exception ):
25
+ """Raised on prompt template load"""
26
+
27
+ pass
28
+
29
+
21
30
def get_prompt_from_template (prompt_template : PromptTemplate , template_vars : dict ):
22
31
"""Get a prompt from a template and a list of messages."""
23
32
if version .parse (package_version ("jinja2" )) < version .parse ("3.0.0" ):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
91
100
if template_name in model_name .lower ():
92
101
return template_name
93
102
else :
94
- raise LookupError ("Could not find template from model name." )
103
+ raise TemplateLoadError ("Could not find template from model name." )
95
104
96
105
97
106
def get_template_from_file (prompt_template_name : str ):
@@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
105
114
)
106
115
else :
107
116
# Let the user know if the template file isn't found
108
- raise FileNotFoundError (f'Template "{ prompt_template_name } " not found.' )
117
+ raise TemplateLoadError (
118
+ f'Chat template "{ prompt_template_name } " not found in files.'
119
+ )
109
120
110
121
111
122
# Get a template from a JSON file
112
123
# Requires a key and template name
113
- def get_template_from_model_json (json_path : pathlib .Path , key : str , name : str ):
124
+ def get_template_from_model_json (
125
+ json_path : pathlib .Path , key : str , name : Optional [str ] = None
126
+ ):
114
127
"""Get a template from a JSON file. Requires a key and template name"""
115
- if json_path .exists ():
116
- with open (json_path , "r" , encoding = "utf8" ) as config_file :
117
- model_config = json .load (config_file )
118
- chat_template = model_config .get (key )
119
- if chat_template :
120
- return PromptTemplate (name = name , template = chat_template )
121
- else :
122
- raise FileNotFoundError (f'Model JSON path "{ json_path } " not found.' )
128
+ if not json_path .exists ():
129
+ raise TemplateLoadError (f'Model JSON path "{ json_path } " not found.' )
130
+
131
+ with open (json_path , "r" , encoding = "utf8" ) as config_file :
132
+ model_config = json .load (config_file )
133
+ chat_template = model_config .get (key )
134
+
135
+ if not chat_template :
136
+ raise TemplateLoadError (
137
+ "Could not find a value from chat_template key in the passed JSON. "
138
+ "Check the tokenizer config?"
139
+ )
140
+
141
+ if isinstance (chat_template , list ):
142
+ # Handles the new list style of chat templates
143
+ if name :
144
+ wrapped_template = next (
145
+ (x for x in chat_template if x .get ("name" ) == name ),
146
+ {},
147
+ )
148
+ else :
149
+ wrapped_template = chat_template [0 ]
150
+ name = unwrap (wrapped_template .get ("name" ), "from_tokenizer_config" )
151
+
152
+ selected_template = wrapped_template .get ("template" )
153
+
154
+ if selected_template :
155
+ return PromptTemplate (name = name , template = selected_template )
156
+ else :
157
+ raise TemplateLoadError (
158
+ f'Chat template with name "{ name } " not found '
159
+ "in model templates list."
160
+ )
161
+ else :
162
+ # Can safely assume the chat template is the old style
163
+ return PromptTemplate (name = "from_tokenizer_config" , template = chat_template )
0 commit comments