tool-call: allow --chat-template chatml w/ --jinja, default to chatml upon parsing issue, avoid double bos #11616

Merged: 7 commits, Feb 3, 2025
58 changes: 38 additions & 20 deletions common/common.cpp
@@ -1869,11 +1869,19 @@ std::string common_chat_format_example(const common_chat_template & tmpl, bool u
return common_chat_apply_template(tmpl, msgs, true, use_jinja);
}

#define CHATML_TEMPLATE_SRC \
"{%- for message in messages -%}\n" \
" {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \
"{%- endfor -%}\n" \
"{%- if add_generation_prompt -%}\n" \
" {{- '<|im_start|>assistant\n' -}}\n" \
"{%- endif -%}"

common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override)
{
auto vocab = llama_model_get_vocab(model);
std::string default_template_src = chat_template_override;
std::string template_tool_use_src = chat_template_override;
std::string default_template_src;
std::string template_tool_use_src;

bool has_explicit_template = !chat_template_override.empty();
if (chat_template_override.empty()) {
auto str = llama_model_chat_template(model, /* name */ nullptr);
@@ -1886,21 +1894,21 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
template_tool_use_src = str;
has_explicit_template = true;
}
} else {
default_template_src = chat_template_override;
}
if (default_template_src.empty() || default_template_src == "chatml") {
if (!template_tool_use_src.empty()) {
default_template_src = template_tool_use_src;
} else {
default_template_src = R"(
{%- for message in messages -%}
{{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- "<|im_start|>assistant\n" -}}
{%- endif -%}
)";
default_template_src = CHATML_TEMPLATE_SRC;
}
}
std::string token_bos;
std::string token_eos;
// TODO: update logic that adds BOS and EOS tokens to the tokenized prompt, in favour of the template.
#if 0
auto vocab = llama_model_get_vocab(model);
const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) {
if (token == LLAMA_TOKEN_NULL) {
if (default_template_src.find(jinja_variable_name) != std::string::npos
Expand All @@ -1912,15 +1920,25 @@ common_chat_templates common_chat_templates_from_model(const struct llama_model
return common_token_to_piece(vocab, token, true);
}
};
auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos)
};
token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token");
token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token");
#endif
try {
return {
has_explicit_template,
std::make_unique<minja::chat_template>(default_template_src, token_bos, token_eos),
template_tool_use_src.empty()
? nullptr
: std::make_unique<minja::chat_template>(template_tool_use_src, token_bos, token_eos),
};
} catch (const std::exception & e) {
LOG_ERR("%s: failed to parse chat template: %s\n", __func__, e.what());
return {
has_explicit_template,
std::make_unique<minja::chat_template>(CHATML_TEMPLATE_SRC, token_bos, token_eos),
Collaborator commented:

I think at some point we should no longer fall back to chatml. The chatml fallback was a temporary solution from when chat templates were not yet a common thing.

For example, in such a case we could return an error message like "Chat template is not supported, you must specify a custom template using --chat-template ..." when the user uses the /chat/completions endpoint.

Collaborator (Author) replied:

Either way, it's surprising how much we can get chatml to do with a few "polyfills" (in minja).

nullptr,
};
}
}

//
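For context on the new fallback path: when the template override is "chatml", the model ships no usable template, or the configured template fails to parse, the templates are now built from CHATML_TEMPLATE_SRC above. The sketch below is an illustration only, not part of this PR: it uses the Python jinja2 package as a stand-in for llama.cpp's bundled minja engine to show what that template renders for a short conversation.

```python
# Illustrative only: jinja2 stands in for minja; the template text mirrors
# CHATML_TEMPLATE_SRC from the diff above.
from jinja2 import Template

CHATML_TEMPLATE_SRC = (
    "{%- for message in messages -%}\n"
    "  {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n"
    "{%- endfor -%}\n"
    "{%- if add_generation_prompt -%}\n"
    "  {{- '<|im_start|>assistant\n' -}}\n"
    "{%- endif -%}"
)

prompt = Template(CHATML_TEMPLATE_SRC).render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
    add_generation_prompt=True,
)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
```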
7 changes: 5 additions & 2 deletions examples/server/tests/unit/test_chat_completion.py
@@ -13,9 +13,12 @@ def create_server():
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
]
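For readers unfamiliar with the test harness, the parametrized rows above ultimately drive requests like the one below against a running llama-server started with --jinja and the row's chat_template value. This is a hedged sketch, not the harness code: the 127.0.0.1:8080 address is an assumption, and the tests actually go through their ServerProcess helper rather than raw requests.

```python
# Hedged example request mirroring a row such as
# (None, "Book", "What is the best book", 8, ..., 77, 8, "length", True, "chatml").
import requests

resp = requests.post(
    "http://127.0.0.1:8080/chat/completions",  # assumed default listen address
    json={
        "max_tokens": 8,
        "messages": [
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
        ],
    },
    timeout=30,
)
body = resp.json()
print(body["choices"][0]["finish_reason"])          # the rows above expect "length"
print(body["choices"][0]["message"]["content"])     # matched against re_content
```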
98 changes: 81 additions & 17 deletions examples/server/tests/unit/test_tool_call.py
@@ -67,8 +67,8 @@ def create_server():


def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None):
n_predict = 512
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
@@ -139,40 +139,62 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),

# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),

(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),

(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
# TODO: fix these
# (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predict,
Expand Down Expand Up @@ -252,29 +274,49 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),

("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),

("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),

("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),

("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),

("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),

("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),

("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),

# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),

# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
# ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_weather_tool_call(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = 512
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 256,
@@ -298,30 +340,52 @@ def test_weather_tool_call(hf_repo: str, template_override: Tuple[str, str | Non

@pytest.mark.slow
@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),

(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),

(None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),

('{"code":"print("}', "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
(None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),

('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),

(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),

(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),

(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
(None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),

(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),

# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
(None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),

# (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: Tuple[str, str | None] | None):
def test_hello_world_tool_call(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = 128
server.model_hf_repo = hf_repo
server.model_hf_file = None
if template_override:
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 256,
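The three real-model tests above repeat the same template_override dispatch. The sketch below distills that pattern into a standalone helper for clarity; ServerStub and apply_template_override are illustrative names, not part of the test suite. A (repo, variant) tuple selects a pre-downloaded .jinja file under models/templates/, while a bare string such as "chatml" is forwarded as the server's --chat-template value.

```python
# Distilled from the repeated blocks in test_tool_call.py above.
import os
from typing import Tuple


class ServerStub:
    # Stand-in for the test suite's server fixture.
    chat_template_file: str | None = None
    chat_template: str | None = None


def apply_template_override(server: ServerStub,
                            template_override: str | Tuple[str, str | None] | None) -> None:
    if not template_override:
        return
    if isinstance(template_override, tuple):
        # e.g. ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"): use a downloaded template file.
        (template_hf_repo, template_variant) = template_override
        name = template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')
        server.chat_template_file = f"../../../models/templates/{name}.jinja"
        assert os.path.exists(server.chat_template_file), (
            f"Template file {server.chat_template_file} does not exist. Run "
            f"`python scripts/get_chat_template.py {template_hf_repo} {template_variant} > "
            f"{server.chat_template_file}` to download the template.")
    elif isinstance(template_override, str):
        # e.g. "chatml": passed through, ending up as --chat-template chatml.
        server.chat_template = template_override
```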