Skip to content

Commit 376a0b7

Browse files
authored
Support chat response format (#2046)
* feat: support response_format in chat
* fix: adjust typos
* fix: add trufflehog lint
1 parent a6e4d63 commit 376a0b7

File tree

5 files changed

+156
-7
lines changed

5 files changed

+156
-7
lines changed

.github/workflows/trufflehog.yml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,4 +16,3 @@ jobs:
1616
fetch-depth: 0
1717
- name: Secret Scanning
1818
uses: trufflesecurity/trufflehog@main
19-
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
{
2+
"choices": [
3+
{
4+
"finish_reason": "eos_token",
5+
"index": 0,
6+
"logprobs": null,
7+
"message": {
8+
"content": "{\n \"temperature\": [\n 35,\n 34,\n 36\n ],\n \"unit\": \"°c\"\n}",
9+
"role": "assistant"
10+
}
11+
}
12+
],
13+
"created": 1718044128,
14+
"id": "",
15+
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
16+
"object": "text_completion",
17+
"system_fingerprint": "2.0.5-dev0-native",
18+
"usage": {
19+
"completion_tokens": 39,
20+
"prompt_tokens": 136,
21+
"total_tokens": 175
22+
}
23+
}
Lines changed: 101 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,101 @@
1+
import pytest
2+
import requests
3+
from pydantic import BaseModel
4+
from typing import List
5+
6+
7+
@pytest.fixture(scope="module")
8+
def llama_grammar_handle(launcher):
9+
with launcher(
10+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
11+
num_shard=1,
12+
disable_grammar_support=False,
13+
use_flash_attention=False,
14+
max_batch_prefill_tokens=3000,
15+
) as handle:
16+
yield handle
17+
18+
19+
@pytest.fixture(scope="module")
20+
async def llama_grammar(llama_grammar_handle):
21+
await llama_grammar_handle.health(300)
22+
return llama_grammar_handle.client
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_grammar_response_format_llama_json(llama_grammar, response_snapshot):
27+
28+
class Weather(BaseModel):
29+
unit: str
30+
temperature: List[int]
31+
32+
# send the request
33+
response = requests.post(
34+
f"{llama_grammar.base_url}/v1/chat/completions",
35+
headers=llama_grammar.headers,
36+
json={
37+
"model": "tgi",
38+
"messages": [
39+
{
40+
"role": "system",
41+
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
42+
},
43+
{
44+
"role": "user",
45+
"content": "What's the weather like the next 3 days in San Francisco, CA?",
46+
},
47+
],
48+
"seed": 42,
49+
"max_tokens": 500,
50+
"response_format": {"type": "json_object", "value": Weather.schema()},
51+
},
52+
)
53+
54+
chat_completion = response.json()
55+
called = chat_completion["choices"][0]["message"]["content"]
56+
57+
assert response.status_code == 200
58+
assert (
59+
called
60+
== '{\n "temperature": [\n 35,\n 34,\n 36\n ],\n "unit": "°c"\n}'
61+
)
62+
assert chat_completion == response_snapshot
63+
64+
65+
@pytest.mark.asyncio
66+
async def test_grammar_response_format_llama_error_if_tools_not_installed(
67+
llama_grammar,
68+
):
69+
class Weather(BaseModel):
70+
unit: str
71+
temperature: List[int]
72+
73+
# send the request
74+
response = requests.post(
75+
f"{llama_grammar.base_url}/v1/chat/completions",
76+
headers=llama_grammar.headers,
77+
json={
78+
"model": "tgi",
79+
"messages": [
80+
{
81+
"role": "system",
82+
"content": f"Respond to the users questions and answer them in the following format: {Weather.schema()}",
83+
},
84+
{
85+
"role": "user",
86+
"content": "What's the weather like the next 3 days in San Francisco, CA?",
87+
},
88+
],
89+
"seed": 42,
90+
"max_tokens": 500,
91+
"tools": [],
92+
"response_format": {"type": "json_object", "value": Weather.schema()},
93+
},
94+
)
95+
96+
# 422 (Unprocessable Entity): the server rejects the request because `tools` and `response_format` are mutually exclusive.
97+
assert response.status_code == 422
98+
assert response.json() == {
99+
"error": "Grammar and tools are mutually exclusive",
100+
"error_type": "grammar and tools",
101+
}

router/src/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -89,6 +89,7 @@ pub(crate) enum GrammarType {
8989
/// JSON Schema is a declarative language that allows you to annotate JSON documents
9090
/// with types and descriptions.
9191
#[serde(rename = "json")]
92+
#[serde(alias = "json_object")]
9293
#[schema(example = json ! ({"properties": {"location":{"type": "string"}}}))]
9394
Json(serde_json::Value),
9495
#[serde(rename = "regex")]
@@ -791,6 +792,13 @@ pub(crate) struct ChatRequest {
791792
#[schema(nullable = true, example = "null")]
792793
#[serde(deserialize_with = "deserialize_tool_choice::deserialize")]
793794
pub tool_choice: Option<ToolType>,
795+
796+
/// Response format constraints for the generation.
797+
///
798+
/// NOTE: A request can use `response_format` OR `tools` but not both.
799+
#[serde(default)]
800+
#[schema(nullable = true, default = "null", example = "null")]
801+
pub response_format: Option<GrammarType>,
794802
}
795803

796804
fn default_tool_prompt() -> Option<String> {

router/src/server.rs

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1016,6 +1016,7 @@ async fn chat_completions(
10161016
tool_choice,
10171017
tool_prompt,
10181018
temperature,
1019+
response_format,
10191020
..
10201021
} = req;
10211022

@@ -1030,6 +1031,18 @@ async fn chat_completions(
10301031
other => (true, other),
10311032
};
10321033

1034+
// response_format and tools are mutually exclusive
1035+
if response_format.is_some() && tools.as_ref().is_some() {
1036+
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
1037+
return Err((
1038+
StatusCode::UNPROCESSABLE_ENTITY,
1039+
Json(ErrorResponse {
1040+
error: "Grammar and tools are mutually exclusive".to_string(),
1041+
error_type: "grammar and tools".to_string(),
1042+
}),
1043+
));
1044+
}
1045+
10331046
// extract tool grammar if present
10341047
let tool_grammar = match ToolGrammar::apply(tools, tool_choice) {
10351048
Ok(grammar) => grammar,
@@ -1046,16 +1059,21 @@ async fn chat_completions(
10461059
}
10471060
};
10481061

1049-
let grammar_with_prompt = tool_grammar
1062+
// determine the appropriate arguments for apply_chat_template
1063+
let tools_grammar_prompt = tool_grammar
10501064
.as_ref()
10511065
.map(|t| (GrammarType::Json(serde_json::json!(t)), tool_prompt));
10521066

1053-
let typed_grammar = grammar_with_prompt
1054-
.as_ref()
1055-
.map(|(grammar, _)| grammar.clone());
1067+
let (tools_grammar_prompt, grammar) = match response_format {
1068+
Some(response_format) => (None, Some(response_format)),
1069+
None => (
1070+
tools_grammar_prompt.clone(),
1071+
tools_grammar_prompt.map(|(grammar, _)| grammar.clone()),
1072+
),
1073+
};
10561074

10571075
// apply chat template to flatten the request into a single input
1058-
let inputs = match infer.apply_chat_template(messages, grammar_with_prompt) {
1076+
let inputs = match infer.apply_chat_template(messages, tools_grammar_prompt) {
10591077
Ok(inputs) => inputs,
10601078
Err(err) => {
10611079
metrics::increment_counter!("tgi_request_failure", "err" => "validation");
@@ -1091,7 +1109,7 @@ async fn chat_completions(
10911109
decoder_input_details: !stream,
10921110
seed,
10931111
top_n_tokens: req.top_logprobs,
1094-
grammar: typed_grammar,
1112+
grammar,
10951113
},
10961114
};
10971115

0 commit comments

Comments
 (0)