Skip to content

Commit

Permalink
fix: adjust default tool choice (#2244)
Browse files Browse the repository at this point in the history
* fix: adjust default tool choice

* feat: improve tool choice syntax and response parsing/errors

* fix: remove dev tests

* feat: add ToolChoice to docs
  • Loading branch information
drbh authored and ErikKaum committed Jul 26, 2024
1 parent 26194ad commit 6bdf8d7
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 154 deletions.
15 changes: 14 additions & 1 deletion docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@
"tool_choice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
"$ref": "#/components/schemas/ToolChoice"
}
],
"nullable": true
Expand Down Expand Up @@ -2035,6 +2035,14 @@
}
}
},
"ToolChoice": {
"allOf": [
{
"$ref": "#/components/schemas/ToolType"
}
],
"nullable": true
},
"ToolType": {
"oneOf": [
{
Expand All @@ -2055,6 +2063,11 @@
"$ref": "#/components/schemas/FunctionName"
}
}
},
{
"type": "object",
"default": null,
"nullable": true
}
]
},
Expand Down
233 changes: 119 additions & 114 deletions router/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub(crate) use health::HealthCheck;
use crate::validation::{ValidGenerateRequest, Validation, ValidationError};
use crate::{
ChatTemplateInputs, ChatTemplateVersions, FinishReason, GenerateRequest, HubProcessorConfig,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token,
HubTokenizerConfig, Message, MessageChunk, PrefillToken, TextMessage, Token, ToolChoice,
};
use crate::{
FunctionRef, FunctionsMap, GrammarType, Properties, TokenizerConfigToken, Tool, ToolType, Tools,
Expand Down Expand Up @@ -332,126 +332,131 @@ impl ChatTemplate {
pub struct ToolGrammar {}

impl ToolGrammar {
/// Returns a clone of the first tool in `tools` whose function name equals `name`.
///
/// # Errors
///
/// Returns `InferError::ToolError` carrying a "not found" message when no tool
/// in the slice has a matching function name.
fn find_tool_by_name(tools: &[Tool], name: &str) -> Result<Tool, InferError> {
    match tools.iter().find(|candidate| candidate.function.name == name) {
        Some(found) => Ok(found.clone()),
        None => Err(InferError::ToolError(format!(
            "Tool with name {} not found",
            name
        ))),
    }
}

pub fn apply(
tools: Option<Vec<Tool>>,
tool_choice: Option<ToolType>,
tool_choice: ToolChoice,
) -> Result<Option<Tools>, InferError> {
if let Some((req_tools, tool_choice)) = tools.zip(tool_choice) {
// let tool_prompt = tool_prompt.unwrap_or_default();
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![req_tools
.iter()
.find(|tool| tool.function.name == *name)
.unwrap_or_else(|| panic!("Tool with name {} not found", name))
.clone()]
}
ToolType::Function { function } => {
let tool = req_tools
.iter()
.find(|tool| tool.function.name == function.name)
.unwrap_or_else(|| panic!("Tool with name {} not found", function.name))
.clone();
vec![tool]
// if no tools are provided, we return None
let tools = match tools {
Some(tools) if !tools.is_empty() => tools,
_ => return Ok(None),
};

let tool_choice = tool_choice.0.unwrap_or(ToolType::OneOf);

// if tools are provided and no tool_choice we default to the OneOf
let tools_to_use = match tool_choice {
ToolType::FunctionName(name) => {
vec![Self::find_tool_by_name(&tools, &name)?]
}
ToolType::Function { function } => {
vec![Self::find_tool_by_name(&tools, &function.name)?]
}
ToolType::OneOf => tools,
ToolType::NoTool => return Ok(None),
};

// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
}),
);

let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();

// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};

// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);

// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();

// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);

// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();

// Add '_name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}
ToolType::OneOf => req_tools.to_owned(),
};

// adds the error notification function for LLM feedback if required
let mut text_response_properties = Map::new();
text_response_properties.insert(
"error".to_string(),
serde_json::json!({
"type": "string",
"description": "The error or issue to notify"
}),
);
text_response_properties.insert(
"_name".to_string(),
(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"type": "string",
"const": "notify_error"
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
);

let functions: HashMap<String, serde_json::Value> = tools_to_use
.iter()
.map(|tool| {
let func = tool.function.clone();

// Clone the existing parameters, which are expected to be a JSON object
let mut params = if let Value::Object(params) = &func.arguments {
params.clone()
} else {
Map::new()
};

// Insert the function's description at the top level, outside of properties
params.insert(
"description".to_string(),
Value::String(func.description.clone().unwrap_or_default()),
);

// Ensure 'properties' exists and is an object
let properties = params
.entry("properties".to_string())
.or_insert_with(|| json!({}))
.as_object_mut()
.unwrap();

// Insert the constant for the function name inside 'properties'
properties.insert(
"_name".to_string(),
json!({
"type": "string",
"const": func.name.clone(),
// "description": "The name of the function"
}),
);

// Check if 'required' exists, and it is an array. If not, create an empty array.
let required = params
.entry("required".to_string())
.or_insert_with(|| json!([]))
.as_array_mut()
.unwrap();

// Add '_name' to the 'required' array if it is not already present
if !required.iter().any(|r| r == "_name") {
required.push(json!("_name"));
}

(func.name, Value::Object(params))
})
.chain([(
"notify_error".to_string(),
serde_json::json!({
"properties": text_response_properties,
"required": ["error", "_name"],
"type": "object"
}),
)])
.collect();

let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};

return Ok(Some(tools));
}
// Err(InferError::ToolError("No tools provided".to_string()))
Ok(None)
)])
.collect();

let tools = Tools {
functions_map: FunctionsMap { functions },
properties: Properties {
function: tools_to_use
.iter()
.map(|tool| FunctionRef {
ref_path: format!("#/$functions/{}", tool.function.name.clone()),
})
.chain(std::iter::once(FunctionRef {
ref_path: "#/$functions/notify_error".to_string(),
}))
.collect(),
},
};

Ok(Some(tools))
}
}

Expand Down
20 changes: 10 additions & 10 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ pub(crate) struct ChatRequest {
/// A specific tool to use. If not provided, the model will default to use any of the tools provided in the tools parameter.
#[serde(default)]
#[schema(nullable = true, example = "null")]
pub tool_choice: Option<ToolType>,
pub tool_choice: ToolChoice,

/// Response format constraints for the generation.
///
Expand All @@ -848,34 +848,34 @@ pub enum ToolType {
OneOf,
FunctionName(String),
Function { function: FunctionName },
NoTool,
}

// Object wrapper around a function name; it is the payload of the
// `ToolType::Function { function }` variant.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ToSchema)]
pub struct FunctionName {
// The function's name.
pub name: String,
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, ToSchema)]
#[serde(from = "ToolTypeDeserializer")]
pub struct ToolChoice(pub Option<ToolType>);

#[derive(Deserialize)]
#[serde(untagged)]
enum ToolTypeDeserializer {
None(Option<String>),
Some(ToolType),
String(String),
ToolType(ToolType),
}

impl From<ToolTypeDeserializer> for ToolChoice {
fn from(value: ToolTypeDeserializer) -> Self {
match value {
ToolTypeDeserializer::None(opt) => match opt.as_deref() {
Some("none") => ToolChoice(None),
Some("auto") => ToolChoice(Some(ToolType::OneOf)),
Some(s) => ToolChoice(Some(ToolType::FunctionName(s.to_string()))),
None => ToolChoice(Some(ToolType::OneOf)),
ToolTypeDeserializer::String(s) => match s.as_str() {
"none" => ToolChoice(Some(ToolType::NoTool)),
"auto" => ToolChoice(Some(ToolType::OneOf)),
_ => ToolChoice(Some(ToolType::FunctionName(s))),
},
ToolTypeDeserializer::Some(tool_type) => ToolChoice(Some(tool_type)),
ToolTypeDeserializer::ToolType(tool_type) => ToolChoice(Some(tool_type)),
}
}
}
Expand Down
Loading

0 comments on commit 6bdf8d7

Please sign in to comment.