Skip to content

Commit

Permalink
fix: simplify naming, tool choice default and improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Oct 16, 2024
1 parent d85133e commit 6837c5b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 88 deletions.
78 changes: 39 additions & 39 deletions docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -829,43 +829,6 @@
}
}
},
"ChatCompletionToolChoiceOption": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"ChatCompletionTopLogprob": {
"type": "object",
"required": [
Expand Down Expand Up @@ -960,7 +923,7 @@
"$ref": "#/components/schemas/GrammarType"
}
],
"default": "null",
"default": "auto",
"nullable": true
},
"seed": {
Expand Down Expand Up @@ -1000,7 +963,7 @@
"tool_choice": {
"allOf": [
{
"$ref": "#/components/schemas/ChatCompletionToolChoiceOption"
"$ref": "#/components/schemas/ToolChoice"
}
],
"default": "null",
Expand Down Expand Up @@ -2141,6 +2104,43 @@
}
}
},
"ToolChoice": {
"oneOf": [
{
"type": "string",
"description": "Means the model can pick between generating a message or calling one or more tools.",
"enum": [
"auto"
]
},
{
"type": "string",
"description": "Means the model will not call any tool and instead generates a message.",
"enum": [
"none"
]
},
{
"type": "string",
"description": "Means the model must call one or more tools.",
"enum": [
"required"
]
},
{
"type": "object",
"required": [
"function"
],
"properties": {
"function": {
"$ref": "#/components/schemas/FunctionName"
}
}
}
],
"description": "<https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>"
},
"Url": {
"type": "object",
"required": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"logprobs": null
}
],
"created": 1729000499,
"created": 1729084854,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"logprobs": null
}
],
"created": 1728998230,
"created": 1729084850,
"id": "",
"model": "meta-llama/Llama-3.1-8B-Instruct",
"object": "chat.completion.chunk",
Expand Down
7 changes: 3 additions & 4 deletions integration-tests/models/test_tools_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
"tools": tools,
"tool_choice": {
"type": "function",
"function": {"name": "get_current_weather"},
"function": {"name": "get_n_day_weather_forecast"},
},
"seed": 24,
"max_tokens": 100,
Expand All @@ -421,10 +421,9 @@ async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
]["arguments"]
last_response = response

assert count == 30
print(tool_calls_generated)
assert count == 39
assert (
tool_calls_generated
== '{"function": {"_name": "get_current_weather", "format": "celsius", "location": "Tokyo, JP"}}<|eot_id|>'
== '{"function": {"_name": "get_n_day_weather_forecast", "format": "celsius", "location": "San Francisco, CA", "num_days":3}}<|eot_id|>'
)
assert last_response == response_snapshot
17 changes: 8 additions & 9 deletions router/src/infer/tool_grammar.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::infer::InferError;
use crate::{
ChatCompletionToolChoiceOption, FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool,
Properties, Tool,
FunctionDefinition, FunctionRef, FunctionsMap, JsonSchemaTool, Properties, Tool, ToolChoice,
};
use serde_json::{json, Map, Value};
use std::collections::HashMap;
Expand All @@ -20,19 +19,19 @@ impl ToolGrammar {

pub fn apply(
tools: Vec<Tool>,
tool_choice: ChatCompletionToolChoiceOption,
tool_choice: ToolChoice,
) -> Result<(Vec<Tool>, Option<JsonSchemaTool>), InferError> {
// if no tools are provided, we return None
// if no tools are provided, we return None and an empty vec
if tools.is_empty() {
return Ok((tools, None));
return Ok((Vec::with_capacity(0), None));
}

let tools_to_use = match tool_choice {
ChatCompletionToolChoiceOption::Function(function) => {
ToolChoice::Function(function) => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ChatCompletionToolChoiceOption::Required => tools,
ChatCompletionToolChoiceOption::Auto => {
ToolChoice::Required => tools,
ToolChoice::Auto => {
// only add the no_tool function if the user has selected the auto option
tools
.iter()
Expand All @@ -58,7 +57,7 @@ impl ToolGrammar {
}))
.collect::<Vec<_>>()
}
ChatCompletionToolChoiceOption::NoTool => Vec::with_capacity(0),
ToolChoice::NoTool => Vec::with_capacity(0),
};

let functions: HashMap<String, serde_json::Value> = tools_to_use
Expand Down
52 changes: 20 additions & 32 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,13 +850,13 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
pub tool_choice: Option<ChatCompletionToolChoiceOption>,
pub tool_choice: Option<ToolChoice>,

/// Response format constraints for the generation.
///
/// NOTE: A request can use `response_format` OR `tools` but not both.
#[serde(default)]
#[schema(nullable = true, default = "null", example = "null")]
#[schema(nullable = true, default = "auto", example = "auto")]
pub response_format: Option<GrammarType>,

/// A guideline to be used in the chat_template
Expand Down Expand Up @@ -903,14 +903,8 @@ impl ChatRequest {
Some(temperature) if temperature == 0.0 => (false, None),
other => (true, other),
};
// unwrap or default (use "auto" if tools are present, and "none" if not)
let tool_choice = tool_choice.unwrap_or_else(|| {
if tools.is_some() {
ChatCompletionToolChoiceOption::Auto
} else {
ChatCompletionToolChoiceOption::NoTool
}
});
// if no tool_choice is set, set default (Auto)
let tool_choice = tool_choice.unwrap_or_default();

if response_format.is_some() && tools.is_some() {
return Err(InferError::ToolError(
Expand Down Expand Up @@ -1002,21 +996,18 @@ pub struct FunctionName {

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema, Default)]
#[serde(from = "ToolTypeDeserializer")]
#[serde(rename_all = "snake_case")]
/// <https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter>
pub enum ChatCompletionToolChoiceOption {
pub enum ToolChoice {
/// Means the model can pick between generating a message or calling one or more tools.
#[schema(rename = "auto")]
#[default]
Auto,
/// Means the model will not call any tool and instead generates a message.
#[schema(rename = "none")]
#[default]
NoTool,
/// Means the model must call one or more tools.
#[schema(rename = "required")]
Required,
/// Forces the model to call a specific tool. This structure aligns with the `OpenAI` API schema to force a specific tool.
#[schema(rename = "function")]
#[serde(alias = "function")]
Function(FunctionName),
}

Expand All @@ -1042,18 +1033,18 @@ enum ToolTypeDeserializer {
TypedChoice(TypedChoice),
}

impl From<ToolTypeDeserializer> for ChatCompletionToolChoiceOption {
impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
ToolTypeDeserializer::Null => ChatCompletionToolChoiceOption::NoTool,
ToolTypeDeserializer::Null => ToolChoice::NoTool,
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ChatCompletionToolChoiceOption::NoTool,
"auto" => ChatCompletionToolChoiceOption::Auto,
"required" => ChatCompletionToolChoiceOption::Required,
_ => ChatCompletionToolChoiceOption::Function(FunctionName { name: s }),
"none" => ToolChoice::NoTool,
"auto" => ToolChoice::Auto,
"required" => ToolChoice::Required,
_ => ToolChoice::Function(FunctionName { name: s }),
},
ToolTypeDeserializer::TypedChoice(TypedChoice::Function { function }) => {
ChatCompletionToolChoiceOption::Function(function)
ToolChoice::Function(function)
}
}
}
Expand Down Expand Up @@ -1667,26 +1658,23 @@ mod tests {
fn tool_choice_formats() {
#[derive(Deserialize)]
struct TestRequest {
tool_choice: ChatCompletionToolChoiceOption,
tool_choice: ToolChoice,
}

let de_none: TestRequest = serde_json::from_str(r#"{"tool_choice":"none"}"#).unwrap();
assert_eq!(de_none.tool_choice, ChatCompletionToolChoiceOption::NoTool);
assert_eq!(de_none.tool_choice, ToolChoice::NoTool);

let de_auto: TestRequest = serde_json::from_str(r#"{"tool_choice":"auto"}"#).unwrap();
assert_eq!(de_auto.tool_choice, ChatCompletionToolChoiceOption::Auto);
assert_eq!(de_auto.tool_choice, ToolChoice::Auto);

let de_required: TestRequest =
serde_json::from_str(r#"{"tool_choice":"required"}"#).unwrap();
assert_eq!(
de_required.tool_choice,
ChatCompletionToolChoiceOption::Required
);
assert_eq!(de_required.tool_choice, ToolChoice::Required);

let de_named: TestRequest = serde_json::from_str(r#"{"tool_choice":"myfn"}"#).unwrap();
assert_eq!(
de_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);
Expand All @@ -1697,7 +1685,7 @@ mod tests {
.unwrap();
assert_eq!(
de_openai_named.tool_choice,
ChatCompletionToolChoiceOption::Function(FunctionName {
ToolChoice::Function(FunctionName {
name: "myfn".to_string(),
})
);
Expand Down
4 changes: 2 additions & 2 deletions router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool,
};
use crate::{ChatCompletionToolChoiceOption, FunctionDefinition, HubPreprocessorConfig, ToolCall};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
use crate::{ModelInfo, ModelsInfo};
use async_stream::__private::AsyncStream;
use axum::extract::Extension;
Expand Down Expand Up @@ -1563,7 +1563,7 @@ Tool,
ToolCall,
Function,
FunctionDefinition,
ChatCompletionToolChoiceOption,
ToolChoice,
ModelInfo,
)
),
Expand Down

0 comments on commit 6837c5b

Please sign in to comment.