Skip to content

Commit e5195e7

Browse files
committed
simplify and cache tools
1 parent ff97618 commit e5195e7

File tree

2 files changed

+67
-60
lines changed

2 files changed

+67
-60
lines changed

cli/src/llm_client.rs renamed to cli/src/llm_client_oai.rs

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,27 @@ impl OpenAIClient {
3030
#[async_trait]
3131
impl LlmClient for OpenAIClient {
3232
async fn create_session(&self, config: SessionConfig) -> Result<Box<dyn LlmSession>> {
33+
let cached_tools = config
34+
.tools
35+
.iter()
36+
.map(|t| oai::ChatCompletionTool {
37+
r#type: oai::ChatCompletionToolType::Function,
38+
function: oai::FunctionObject {
39+
name: t.name.clone(),
40+
description: Some(t.description.clone()),
41+
parameters: Some(tool_definition_to_params_json(t)),
42+
strict: Some(true),
43+
},
44+
})
45+
.collect::<Vec<_>>();
46+
3347
Ok(Box::new(OpenAISession {
3448
client: self.client.clone(),
3549
base_url: self.base_url.clone(),
3650
api_key: self.api_key.clone(),
3751
config,
3852
messages: Vec::new(),
53+
cached_tools,
3954
}))
4055
}
4156
}
@@ -45,6 +60,7 @@ struct OpenAISession {
4560
base_url: String,
4661
api_key: String,
4762
config: SessionConfig,
63+
cached_tools: Vec<oai::ChatCompletionTool>,
4864
messages: Vec<oai::ChatCompletionRequestMessage>,
4965
}
5066

@@ -72,6 +88,30 @@ fn tool_definition_to_params_json(tool: &ToolDefinition) -> serde_json::Value {
7288
}
7389

7490
impl OpenAISession {
91+
fn new(client: Client, base_url: String, api_key: String, config: SessionConfig) -> Self {
92+
let cached_tools = config
93+
.tools
94+
.iter()
95+
.map(|t| oai::ChatCompletionTool {
96+
r#type: oai::ChatCompletionToolType::Function,
97+
function: oai::FunctionObject {
98+
name: t.name.clone(),
99+
description: Some(t.description.clone()),
100+
parameters: Some(tool_definition_to_params_json(t)),
101+
strict: None,
102+
},
103+
})
104+
.collect();
105+
Self {
106+
client,
107+
base_url,
108+
api_key,
109+
config,
110+
messages: Vec::new(),
111+
cached_tools,
112+
}
113+
}
114+
75115
fn add_user_message(&mut self, content: String) {
76116
// Add system prompt if this is the first message
77117
if self.messages.is_empty() && self.config.system_prompt.is_some() {
@@ -109,25 +149,10 @@ impl OpenAISession {
109149
async fn complete(&mut self) -> Result<CompletionResult> {
110150
// Build request
111151

112-
let tools = self
113-
.config
114-
.tools
115-
.iter()
116-
.map(|t| oai::ChatCompletionTool {
117-
r#type: oai::ChatCompletionToolType::Function,
118-
function: oai::FunctionObject {
119-
name: t.name.clone(),
120-
description: Some(t.description.clone()),
121-
parameters: Some(tool_definition_to_params_json(&t)),
122-
strict: None,
123-
},
124-
})
125-
.collect();
126-
127152
let request = oai::CreateChatCompletionRequest {
128153
model: self.config.model.clone(),
129154
messages: self.messages.clone(),
130-
tools: Some(tools),
155+
tools: Some(self.cached_tools.clone()),
131156
parallel_tool_calls: Some(true),
132157
..Default::default()
133158
};
@@ -157,57 +182,41 @@ impl OpenAISession {
157182

158183
let choice = completion
159184
.choices
160-
.into_iter()
161-
.next()
185+
.first()
162186
.ok_or_else(|| anyhow!("No choices in response"))?;
163187

164-
let message = choice.message;
188+
let message = &choice.message;
189+
190+
let message_content = message
191+
.content
192+
.clone()
193+
.map(ChatCompletionRequestAssistantMessageContent::Text);
194+
195+
self
196+
.messages
197+
.push(oai::ChatCompletionRequestMessage::Assistant(
198+
oai::ChatCompletionRequestAssistantMessage {
199+
content: message_content,
200+
tool_calls: message.tool_calls.clone(),
201+
..Default::default()
202+
},
203+
));
165204

166205
let result = CompletionResult {
167206
content: message.content.clone(),
168207
tool_calls: message
169208
.tool_calls
170-
.unwrap_or_default()
171-
.into_iter()
209+
.as_ref()
210+
.unwrap_or(&Vec::new())
211+
.iter()
172212
.map(|tc| ToolCall {
173-
id: tc.id,
174-
name: tc.function.name,
213+
id: tc.id.clone(),
214+
name: tc.function.name.clone(),
175215
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
176216
})
177217
.collect(),
178218
};
179219

180-
// Add assistant message to history
181-
self
182-
.messages
183-
.push(oai::ChatCompletionRequestMessage::Assistant(
184-
oai::ChatCompletionRequestAssistantMessage {
185-
content: result
186-
.content
187-
.clone()
188-
.map(ChatCompletionRequestAssistantMessageContent::Text),
189-
tool_calls: if result.tool_calls.is_empty() {
190-
None
191-
} else {
192-
Some(
193-
result
194-
.tool_calls
195-
.iter()
196-
.map(|tc| oai::ChatCompletionMessageToolCall {
197-
id: tc.id.clone(),
198-
r#type: oai::ChatCompletionToolType::Function,
199-
function: oai::FunctionCall {
200-
name: tc.name.clone(),
201-
arguments: tc.arguments.to_string(),
202-
},
203-
})
204-
.collect(),
205-
)
206-
},
207-
..Default::default()
208-
},
209-
));
210-
211220
Ok(result)
212221
}
213222
}
@@ -223,7 +232,7 @@ impl LlmSession for OpenAISession {
223232
self.add_user_message(content);
224233
let result = self.complete().await?;
225234

226-
// Just emit the final completion event
235+
// Just emit the final completion event for now
227236
Ok(Box::new(Box::pin(stream::once(async move {
228237
Ok(LlmEvent::Completion(result))
229238
}))))
@@ -241,7 +250,7 @@ impl LlmSession for OpenAISession {
241250
self.add_tool_results(results);
242251
let result = self.complete().await?;
243252

244-
// Just emit the final completion event
253+
// Just emit the final completion event for now
245254
Ok(Box::new(Box::pin(stream::once(async move {
246255
Ok(LlmEvent::Completion(result))
247256
}))))

cli/src/main.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
mod config;
22
mod executor;
3-
mod llm_client;
3+
mod llm_client_oai;
44
mod llm_client_trait;
55

66
use anyhow::Result;
@@ -11,8 +11,6 @@ use std::path::PathBuf;
1111
use tracing::{debug, error, info};
1212
use tracing_subscriber::EnvFilter;
1313

14-
use crate::config::ToolDefinition;
15-
1614
#[derive(Parser, Debug)]
1715
#[command(author, version, about, long_about = None)]
1816
struct Args {
@@ -41,7 +39,7 @@ async fn main() -> Result<()> {
4139
let api_key = std::env::var("LLM_CLI_TOKEN")?;
4240

4341
// Create LLM client
44-
let client = llm_client::OpenAIClient::new(base_url, api_key);
42+
let client = llm_client_oai::OpenAIClient::new(base_url, api_key);
4543

4644
let model = args
4745
.model

0 commit comments

Comments
 (0)