Skip to content

Commit d52c60b

Browse files
committed
encapsulate chat template logic
1 parent 88e0813 commit d52c60b

File tree

1 file changed

+72
-48
lines changed

1 file changed

+72
-48
lines changed

src/transformers/tokenization_utils_base.py

Lines changed: 72 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,54 +1764,7 @@ def apply_chat_template(
17641764
if tokenizer_kwargs is None:
17651765
tokenizer_kwargs = {}
17661766

1767-
using_default_template = False
1768-
1769-
# First, handle the cases when the model has a dict of multiple templates
1770-
if isinstance(self.chat_template, dict) or (
1771-
self.chat_template is None and isinstance(self.default_chat_template, dict)
1772-
):
1773-
if self.chat_template is not None:
1774-
template_dict = self.chat_template
1775-
using_default_dict = False
1776-
else:
1777-
template_dict = self.default_chat_template
1778-
using_default_dict = True
1779-
if chat_template is not None and chat_template in template_dict:
1780-
# The user can pass the name of a template to the chat template argument instead of an entire template
1781-
chat_template = template_dict[chat_template]
1782-
if using_default_dict:
1783-
using_default_template = True
1784-
elif chat_template is None:
1785-
if tools is not None and "tool_use" in template_dict:
1786-
chat_template = template_dict["tool_use"]
1787-
elif "default" in template_dict:
1788-
chat_template = template_dict["default"]
1789-
else:
1790-
raise ValueError(
1791-
"This model has multiple chat templates with no default specified! Please either pass a chat "
1792-
"template or the name of the template you wish to use to the `chat_template` argument. Available "
1793-
f"template names are {sorted(template_dict.keys())}."
1794-
)
1795-
if using_default_dict:
1796-
using_default_template = True
1797-
1798-
elif chat_template is None:
1799-
# These are the cases when the model has a single template
1800-
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
1801-
if self.chat_template is not None:
1802-
chat_template = self.chat_template
1803-
else:
1804-
chat_template = self.default_chat_template
1805-
using_default_template = True
1806-
1807-
if using_default_template:
1808-
logger.warning_once(
1809-
"No chat template is set for this tokenizer, falling back to a default class-level template. This is "
1810-
"very error-prone, because models are often trained with templates different from the class default! "
1811-
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
1812-
"point any code depending on them will stop working. We recommend setting a valid chat template before "
1813-
"then to ensure that this model continues working without issues."
1814-
)
1767+
chat_template = self.get_chat_template(chat_template, tools)
18151768

18161769
# Compilation function uses a cache to avoid recompiling the same template
18171770
compiled_template = self._compile_jinja_template(chat_template)
@@ -1908,6 +1861,77 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
19081861
jinja_env.globals["raise_exception"] = raise_exception
19091862
return jinja_env.from_string(chat_template)
19101863

1864+
def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
    """
    Retrieve the chat template string used for tokenizing chat messages. This template is used
    internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat
    template for better generation tracking.

    Resolution priority is: `chat_template` argument > `self.chat_template` > `self.default_chat_template`.
    When the tokenizer holds a dict of named templates, the `chat_template` argument may be one of the dict's
    keys; any other non-`None` value is returned unchanged and treated as a literal template by the caller.

    Args:
        chat_template (`str`, *optional*):
            A Jinja template or the name of a template to use for this conversion.
            It is usually not necessary to pass anything to this argument,
            as the model's template will be used by default.
        tools (`List[Dict]`, *optional*):
            A list of tools (callable functions) that will be accessible to the model. If the template does not
            support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
            giving the name, description and argument types for the tool. See our
            [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
            for more information.

    Returns:
        `str`: The chat template string.

    Raises:
        `ValueError`: If the tokenizer holds a dict of templates, no `chat_template` argument is passed, and
            neither a `"tool_use"` entry (when `tools` is given) nor a `"default"` entry can be selected.
    """
    using_default_template = False

    # First, handle the cases when the model has a dict of multiple templates
    if isinstance(self.chat_template, dict) or (
        self.chat_template is None and isinstance(self.default_chat_template, dict)
    ):
        # The instance-level dict wins; the class-level default dict is only consulted when nothing is set.
        using_default_dict = self.chat_template is None
        template_dict = self.default_chat_template if using_default_dict else self.chat_template

        if chat_template is not None and chat_template in template_dict:
            # The user can pass the name of a template to the chat template argument instead of an entire template
            chat_template = template_dict[chat_template]
            using_default_template = using_default_dict
        elif chat_template is None:
            if tools is not None and "tool_use" in template_dict:
                chat_template = template_dict["tool_use"]
            elif "default" in template_dict:
                chat_template = template_dict["default"]
            else:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
                    "template or the name of the template you wish to use to the `chat_template` argument. Available "
                    f"template names are {sorted(template_dict.keys())}."
                )
            using_default_template = using_default_dict
        # Otherwise `chat_template` is a literal template supplied by the caller: return it untouched,
        # and do not warn, since no class-level default was actually used.

    elif chat_template is None:
        # These are the cases when the model has a single template
        # priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
        if self.chat_template is not None:
            chat_template = self.chat_template
        else:
            chat_template = self.default_chat_template
            using_default_template = True

    if using_default_template:
        logger.warning_once(
            "No chat template is set for this tokenizer, falling back to a default class-level template. This is "
            "very error-prone, because models are often trained with templates different from the class default! "
            "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
            "point any code depending on them will stop working. We recommend setting a valid chat template before "
            "then to ensure that this model continues working without issues."
        )

    return chat_template
1934+
19111935
@property
19121936
def default_chat_template(self):
19131937
"""

0 commit comments

Comments
 (0)