Skip to content

Commit c316ce8

Browse files
authored
fix: auto retry explicitly for OpenAI requests (#1172)
1 parent a10c86f commit c316ce8

File tree

4 files changed

+115
-83
lines changed

4 files changed

+115
-83
lines changed

src/execution/live_updater.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ impl SourceUpdateTask {
164164
.next()
165165
.await
166166
.transpose()
167-
.map_err(retryable::Error::always_retryable)
167+
.map_err(retryable::Error::retryable)
168168
},
169169
&retry_options,
170170
)

src/llm/openai.rs

Lines changed: 102 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
use crate::api_bail;
1+
use crate::prelude::*;
2+
use base64::prelude::*;
23

34
use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
4-
use anyhow::Result;
55
use async_openai::{
66
Client as OpenAIClient,
77
config::OpenAIConfig,
8+
error::OpenAIError,
89
types::{
910
ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
1011
ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -14,8 +15,6 @@ use async_openai::{
1415
ResponseFormat, ResponseFormatJsonSchema,
1516
},
1617
};
17-
use async_trait::async_trait;
18-
use base64::prelude::*;
1918
use phf::phf_map;
2019

2120
static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
@@ -62,77 +61,99 @@ impl Client {
6261
}
6362
}
6463

65-
#[async_trait]
66-
impl LlmGenerationClient for Client {
67-
async fn generate<'req>(
68-
&self,
69-
request: super::LlmGenerateRequest<'req>,
70-
) -> Result<super::LlmGenerateResponse> {
71-
let mut messages = Vec::new();
72-
73-
// Add system prompt if provided
74-
if let Some(system) = request.system_prompt {
75-
messages.push(ChatCompletionRequestMessage::System(
76-
ChatCompletionRequestSystemMessage {
77-
content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
78-
..Default::default()
79-
},
80-
));
64+
impl utils::retryable::IsRetryable for OpenAIError {
65+
fn is_retryable(&self) -> bool {
66+
match self {
67+
OpenAIError::Reqwest(e) => e.is_retryable(),
68+
_ => false,
8169
}
70+
}
71+
}
8272

83-
// Add user message
84-
let user_message_content = match request.image {
85-
Some(img_bytes) => {
86-
let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
87-
let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
88-
let image_url = format!("data:{mime_type};base64,{base64_image}");
89-
ChatCompletionRequestUserMessageContent::Array(vec![
90-
ChatCompletionRequestUserMessageContentPart::Text(
91-
ChatCompletionRequestMessageContentPartText {
92-
text: request.user_prompt.into_owned(),
93-
},
94-
),
95-
ChatCompletionRequestUserMessageContentPart::ImageUrl(
96-
ChatCompletionRequestMessageContentPartImage {
97-
image_url: async_openai::types::ImageUrl {
98-
url: image_url,
99-
detail: Some(ImageDetail::Auto),
100-
},
101-
},
102-
),
103-
])
104-
}
105-
None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
106-
};
107-
messages.push(ChatCompletionRequestMessage::User(
108-
ChatCompletionRequestUserMessage {
109-
content: user_message_content,
73+
fn create_llm_generation_request(
74+
request: &super::LlmGenerateRequest,
75+
) -> Result<CreateChatCompletionRequest> {
76+
let mut messages = Vec::new();
77+
78+
// Add system prompt if provided
79+
if let Some(system) = &request.system_prompt {
80+
messages.push(ChatCompletionRequestMessage::System(
81+
ChatCompletionRequestSystemMessage {
82+
content: ChatCompletionRequestSystemMessageContent::Text(system.to_string()),
11083
..Default::default()
11184
},
11285
));
86+
}
11387

114-
// Create the chat completion request
115-
let request = CreateChatCompletionRequest {
116-
model: request.model.to_string(),
117-
messages,
118-
response_format: match request.output_format {
119-
Some(super::OutputFormat::JsonSchema { name, schema }) => {
120-
Some(ResponseFormat::JsonSchema {
121-
json_schema: ResponseFormatJsonSchema {
122-
name: name.into_owned(),
123-
description: None,
124-
schema: Some(serde_json::to_value(&schema)?),
125-
strict: Some(true),
88+
// Add user message
89+
let user_message_content = match &request.image {
90+
Some(img_bytes) => {
91+
let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
92+
let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
93+
let image_url = format!("data:{mime_type};base64,{base64_image}");
94+
ChatCompletionRequestUserMessageContent::Array(vec![
95+
ChatCompletionRequestUserMessageContentPart::Text(
96+
ChatCompletionRequestMessageContentPartText {
97+
text: request.user_prompt.to_string(),
98+
},
99+
),
100+
ChatCompletionRequestUserMessageContentPart::ImageUrl(
101+
ChatCompletionRequestMessageContentPartImage {
102+
image_url: async_openai::types::ImageUrl {
103+
url: image_url,
104+
detail: Some(ImageDetail::Auto),
126105
},
127-
})
128-
}
129-
None => None,
130-
},
106+
},
107+
),
108+
])
109+
}
110+
None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.to_string()),
111+
};
112+
messages.push(ChatCompletionRequestMessage::User(
113+
ChatCompletionRequestUserMessage {
114+
content: user_message_content,
131115
..Default::default()
132-
};
116+
},
117+
));
118+
// Create the chat completion request
119+
let request = CreateChatCompletionRequest {
120+
model: request.model.to_string(),
121+
messages,
122+
response_format: match &request.output_format {
123+
Some(super::OutputFormat::JsonSchema { name, schema }) => {
124+
Some(ResponseFormat::JsonSchema {
125+
json_schema: ResponseFormatJsonSchema {
126+
name: name.to_string(),
127+
description: None,
128+
schema: Some(serde_json::to_value(&schema)?),
129+
strict: Some(true),
130+
},
131+
})
132+
}
133+
None => None,
134+
},
135+
..Default::default()
136+
};
133137

134-
// Send request and get response
135-
let response = self.client.chat().create(request).await?;
138+
Ok(request)
139+
}
140+
141+
#[async_trait]
142+
impl LlmGenerationClient for Client {
143+
async fn generate<'req>(
144+
&self,
145+
request: super::LlmGenerateRequest<'req>,
146+
) -> Result<super::LlmGenerateResponse> {
147+
let request = &request;
148+
let response = retryable::run(
149+
|| async {
150+
let req = create_llm_generation_request(request)?;
151+
let response = self.client.chat().create(req).await?;
152+
retryable::Ok(response)
153+
},
154+
&retryable::RetryOptions::default(),
155+
)
156+
.await?;
136157

137158
// Extract the response text from the first choice
138159
let text = response
@@ -161,16 +182,21 @@ impl LlmEmbeddingClient for Client {
161182
&self,
162183
request: super::LlmEmbeddingRequest<'req>,
163184
) -> Result<super::LlmEmbeddingResponse> {
164-
let response = self
165-
.client
166-
.embeddings()
167-
.create(CreateEmbeddingRequest {
168-
model: request.model.to_string(),
169-
input: EmbeddingInput::String(request.text.to_string()),
170-
dimensions: request.output_dimension,
171-
..Default::default()
172-
})
173-
.await?;
185+
let response = retryable::run(
186+
|| async {
187+
self.client
188+
.embeddings()
189+
.create(CreateEmbeddingRequest {
190+
model: request.model.to_string(),
191+
input: EmbeddingInput::String(request.text.to_string()),
192+
dimensions: request.output_dimension,
193+
..Default::default()
194+
})
195+
.await
196+
},
197+
&retryable::RetryOptions::default(),
198+
)
199+
.await?;
174200
Ok(super::LlmEmbeddingResponse {
175201
embedding: response
176202
.data

src/ops/targets/neo4j.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,8 +1070,7 @@ impl TargetFactoryBase for Factory {
10701070
},
10711071
&retry_options,
10721072
)
1073-
.await
1074-
.map_err(Into::<anyhow::Error>::into)?
1073+
.await?;
10751074
}
10761075
Ok(())
10771076
}

src/utils/retryable.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ pub trait IsRetryable {
99
}
1010

1111
pub struct Error {
12-
error: anyhow::Error,
13-
is_retryable: bool,
12+
pub error: anyhow::Error,
13+
pub is_retryable: bool,
1414
}
1515

1616
pub const DEFAULT_RETRY_TIMEOUT: Duration = Duration::from_secs(10 * 60);
@@ -40,12 +40,19 @@ impl IsRetryable for reqwest::Error {
4040
}
4141

4242
impl Error {
43-
pub fn always_retryable(error: anyhow::Error) -> Self {
43+
pub fn retryable<E: Into<anyhow::Error>>(error: E) -> Self {
4444
Self {
45-
error,
45+
error: error.into(),
4646
is_retryable: true,
4747
}
4848
}
49+
50+
pub fn not_retryable<E: Into<anyhow::Error>>(error: E) -> Self {
51+
Self {
52+
error: error.into(),
53+
is_retryable: false,
54+
}
55+
}
4956
}
5057

5158
impl From<anyhow::Error> for Error {

0 commit comments

Comments (0)