@@ -1,10 +1,11 @@
-use crate::api_bail;
+use crate::prelude::*;
+use base64::prelude::*;
 
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
-use anyhow::Result;
 use async_openai::{
     Client as OpenAIClient,
     config::OpenAIConfig,
+    error::OpenAIError,
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -14,8 +15,6 @@ use async_openai::{
         ResponseFormat, ResponseFormatJsonSchema,
     },
 };
-use async_trait::async_trait;
-use base64::prelude::*;
 use phf::phf_map;
 
 static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
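With `crate::prelude::*` in place, the dropped `anyhow::Result` and `async_trait` imports are presumably covered by the prelude glob (both `Result` and `#[async_trait]` are still used below), as are the `utils::retryable` helpers this diff introduces; the `base64` import simply moves to the top of the file.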
@@ -62,77 +61,99 @@ impl Client {
     }
 }
 
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let mut messages = Vec::new();
-
-        // Add system prompt if provided
-        if let Some(system) = request.system_prompt {
-            messages.push(ChatCompletionRequestMessage::System(
-                ChatCompletionRequestSystemMessage {
-                    content: ChatCompletionRequestSystemMessageContent::Text(system.into_owned()),
-                    ..Default::default()
-                },
-            ));
+impl utils::retryable::IsRetryable for OpenAIError {
+    fn is_retryable(&self) -> bool {
+        match self {
+            OpenAIError::Reqwest(e) => e.is_retryable(),
+            _ => false,
         }
+    }
+}
 
-        // Add user message
-        let user_message_content = match request.image {
-            Some(img_bytes) => {
-                let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
-                let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
-                let image_url = format!("data:{mime_type};base64,{base64_image}");
-                ChatCompletionRequestUserMessageContent::Array(vec![
-                    ChatCompletionRequestUserMessageContentPart::Text(
-                        ChatCompletionRequestMessageContentPartText {
-                            text: request.user_prompt.into_owned(),
-                        },
-                    ),
-                    ChatCompletionRequestUserMessageContentPart::ImageUrl(
-                        ChatCompletionRequestMessageContentPartImage {
-                            image_url: async_openai::types::ImageUrl {
-                                url: image_url,
-                                detail: Some(ImageDetail::Auto),
-                            },
-                        },
-                    ),
-                ])
-            }
-            None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.into_owned()),
-        };
-        messages.push(ChatCompletionRequestMessage::User(
-            ChatCompletionRequestUserMessage {
-                content: user_message_content,
+fn create_llm_generation_request(
+    request: &super::LlmGenerateRequest,
+) -> Result<CreateChatCompletionRequest> {
+    let mut messages = Vec::new();
+
+    // Add system prompt if provided
+    if let Some(system) = &request.system_prompt {
+        messages.push(ChatCompletionRequestMessage::System(
+            ChatCompletionRequestSystemMessage {
+                content: ChatCompletionRequestSystemMessageContent::Text(system.to_string()),
                 ..Default::default()
             },
         ));
+    }
 
-        // Create the chat completion request
-        let request = CreateChatCompletionRequest {
-            model: request.model.to_string(),
-            messages,
-            response_format: match request.output_format {
-                Some(super::OutputFormat::JsonSchema { name, schema }) => {
-                    Some(ResponseFormat::JsonSchema {
-                        json_schema: ResponseFormatJsonSchema {
-                            name: name.into_owned(),
-                            description: None,
-                            schema: Some(serde_json::to_value(&schema)?),
-                            strict: Some(true),
+    // Add user message
+    let user_message_content = match &request.image {
+        Some(img_bytes) => {
+            let base64_image = BASE64_STANDARD.encode(img_bytes.as_ref());
+            let mime_type = detect_image_mime_type(img_bytes.as_ref())?;
+            let image_url = format!("data:{mime_type};base64,{base64_image}");
+            ChatCompletionRequestUserMessageContent::Array(vec![
+                ChatCompletionRequestUserMessageContentPart::Text(
+                    ChatCompletionRequestMessageContentPartText {
+                        text: request.user_prompt.to_string(),
+                    },
+                ),
+                ChatCompletionRequestUserMessageContentPart::ImageUrl(
+                    ChatCompletionRequestMessageContentPartImage {
+                        image_url: async_openai::types::ImageUrl {
+                            url: image_url,
+                            detail: Some(ImageDetail::Auto),
                         },
-                    })
-                }
-                None => None,
-            },
+                    },
+                ),
+            ])
+        }
+        None => ChatCompletionRequestUserMessageContent::Text(request.user_prompt.to_string()),
+    };
+    messages.push(ChatCompletionRequestMessage::User(
+        ChatCompletionRequestUserMessage {
+            content: user_message_content,
             ..Default::default()
-        };
+        },
+    ));
+    // Create the chat completion request
+    let request = CreateChatCompletionRequest {
+        model: request.model.to_string(),
+        messages,
+        response_format: match &request.output_format {
+            Some(super::OutputFormat::JsonSchema { name, schema }) => {
+                Some(ResponseFormat::JsonSchema {
+                    json_schema: ResponseFormatJsonSchema {
+                        name: name.to_string(),
+                        description: None,
+                        schema: Some(serde_json::to_value(&schema)?),
+                        strict: Some(true),
+                    },
+                })
+            }
+            None => None,
+        },
+        ..Default::default()
+    };
 
-        // Send request and get response
-        let response = self.client.chat().create(request).await?;
+    Ok(request)
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: super::LlmGenerateRequest<'req>,
+    ) -> Result<super::LlmGenerateResponse> {
+        let request = &request;
+        let response = retryable::run(
+            || async {
+                let req = create_llm_generation_request(request)?;
+                let response = self.client.chat().create(req).await?;
+                retryable::Ok(response)
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
 
         // Extract the response text from the first choice
         let text = response
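The retry plumbing above comes from the repo's `utils::retryable` module, which this diff only calls into. As a rough mental model, here is a minimal sketch of what `run`, `RetryOptions`, and `IsRetryable` plausibly look like, inferred purely from these call sites; the field names, attempt count, and backoff policy are assumptions, and the real module also appears to define its own result type (hence `retryable::Ok(...)` above), which this sketch flattens to a plain `Result`:

use std::future::Future;
use std::time::Duration;

/// Errors classify themselves as transient (worth retrying) or permanent.
pub trait IsRetryable {
    fn is_retryable(&self) -> bool;
}

pub struct RetryOptions {
    pub max_attempts: u32,
    pub initial_backoff: Duration,
}

impl Default for RetryOptions {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(200),
        }
    }
}

/// Re-run `f` until it succeeds, fails with a non-retryable error,
/// or the attempt budget is exhausted. Backoff doubles per attempt.
pub async fn run<T, E, Fut, F>(mut f: F, options: &RetryOptions) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: IsRetryable,
{
    let mut backoff = options.initial_backoff;
    let mut attempt = 1;
    loop {
        match f().await {
            Err(e) if e.is_retryable() && attempt < options.max_attempts => {
                tokio::time::sleep(backoff).await;
                backoff *= 2;
                attempt += 1;
            }
            result => return result,
        }
    }
}

This shape also explains the `let request = &request;` line in `generate`: `run` needs an `FnMut` closure it can invoke more than once, so the closure must capture a shared reference rather than move the request into its first invocation.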
@@ -161,16 +182,21 @@ impl LlmEmbeddingClient for Client {
         &self,
         request: super::LlmEmbeddingRequest<'req>,
     ) -> Result<super::LlmEmbeddingResponse> {
-        let response = self
-            .client
-            .embeddings()
-            .create(CreateEmbeddingRequest {
-                model: request.model.to_string(),
-                input: EmbeddingInput::String(request.text.to_string()),
-                dimensions: request.output_dimension,
-                ..Default::default()
-            })
-            .await?;
+        let response = retryable::run(
+            || async {
+                self.client
+                    .embeddings()
+                    .create(CreateEmbeddingRequest {
+                        model: request.model.to_string(),
+                        input: EmbeddingInput::String(request.text.to_string()),
+                        dimensions: request.output_dimension,
+                        ..Default::default()
+                    })
+                    .await
+            },
+            &retryable::RetryOptions::default(),
+        )
+        .await?;
         Ok(super::LlmEmbeddingResponse {
             embedding: response
                 .data
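Note the asymmetry between the two retry sites: the embeddings closure hands back the `async_openai` future's result as-is, so the error seen by `retryable::run` is an `OpenAIError` and the new `IsRetryable` impl applies directly, whereas the generation closure also calls `create_llm_generation_request` and therefore wraps its success value in `retryable::Ok`, presumably the module's own result type that can absorb both request-building and API errors. In either case, only `OpenAIError::Reqwest` transport failures are classified as retryable; anything surfaced above the transport layer (API errors such as rate-limit responses, deserialization failures) hits the `_ => false` arm and is returned immediately.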