Commit e3f47fd

Make openai.rs generic to share logic with azureopenai.rs (#1376)
* refactor: make openai.rs generic to support AzureConfig
* fix: apply Rust formatting
* refactor: remove azureopenai.rs
1 parent ef1331a commit e3f47fd
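
In short, the commit deletes the dedicated Azure module and makes `openai::Client` generic over its `async_openai` config type, with `OpenAIConfig` as the default so existing call sites keep compiling. A minimal, self-contained sketch of that pattern, using hypothetical stand-in types rather than the crate's real ones:

```rust
// Sketch only: stand-in names, not the crate's actual code.

// The shared behavior every backend config must provide.
trait Config {
    fn api_base(&self) -> String;
}

struct OpenAIConfig;

impl Config for OpenAIConfig {
    fn api_base(&self) -> String {
        "https://api.openai.com/v1".to_string()
    }
}

struct AzureConfig {
    resource: String,
}

impl Config for AzureConfig {
    fn api_base(&self) -> String {
        format!("https://{}.openai.azure.com", self.resource)
    }
}

// The default type parameter keeps plain `Client` meaning the OpenAI client,
// so existing call sites compile unchanged, while `Client<AzureConfig>`
// reuses the same struct and everything written against `C: Config`.
struct Client<C: Config = OpenAIConfig> {
    config: C,
}

// One generic impl serves every config type; this is what lets the
// Azure-specific module be deleted.
impl<C: Config> Client<C> {
    fn endpoint(&self) -> String {
        self.config.api_base()
    }
}

fn main() {
    let openai: Client = Client { config: OpenAIConfig };
    let azure = Client {
        config: AzureConfig { resource: "my-resource".to_string() },
    };
    println!("{} / {}", openai.endpoint(), azure.endpoint());
}
```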

3 files changed: +55 -133 lines

rust/cocoindex/src/llm/azureopenai.rs

Lines changed: 0 additions & 123 deletions
This file was deleted.

rust/cocoindex/src/llm/mod.rs

Lines changed: 2 additions & 3 deletions
```diff
@@ -116,7 +116,6 @@ pub trait LlmEmbeddingClient: Send + Sync {
 }
 
 mod anthropic;
-mod azureopenai;
 mod bedrock;
 mod gemini;
 mod litellm;
@@ -157,7 +156,7 @@ pub async fn new_llm_generation_client(
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::Voyage => {
@@ -196,7 +195,7 @@ pub async fn new_llm_embedding_client(
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic | LlmApiType::Bedrock => {
```
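
Both factory functions keep the same shape: each match arm instantiates the generic client with the appropriate config and immediately erases the concrete type behind a boxed trait object, so callers never see which config was used. A self-contained sketch of that dispatch, again with hypothetical stand-in types:

```rust
// Sketch only: stand-in names, not the crate's actual API.
trait LlmGenerationClient {
    fn backend(&self) -> &'static str;
}

struct OpenAIConfig;
struct AzureConfig;

struct Client<C> {
    _config: C,
}

impl LlmGenerationClient for Client<OpenAIConfig> {
    fn backend(&self) -> &'static str {
        "openai"
    }
}

impl LlmGenerationClient for Client<AzureConfig> {
    fn backend(&self) -> &'static str {
        "azure-openai"
    }
}

enum LlmApiType {
    OpenAi,
    AzureOpenAi,
}

// Callers only ever receive `Box<dyn LlmGenerationClient>`, so swapping the
// Azure arm from a dedicated module to `Client<AzureConfig>` is invisible.
fn new_llm_generation_client(api_type: LlmApiType) -> Box<dyn LlmGenerationClient> {
    match api_type {
        LlmApiType::OpenAi => Box::new(Client { _config: OpenAIConfig }),
        LlmApiType::AzureOpenAi => Box::new(Client { _config: AzureConfig }),
    }
}

fn main() {
    let client = new_llm_generation_client(LlmApiType::AzureOpenAi);
    println!("{}", client.backend());
}
```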

rust/cocoindex/src/llm/openai.rs

Lines changed: 53 additions & 7 deletions
```diff
@@ -4,7 +4,7 @@ use base64::prelude::*;
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
 use async_openai::{
     Client as OpenAIClient,
-    config::OpenAIConfig,
+    config::{AzureConfig, OpenAIConfig},
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -22,13 +22,15 @@ static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
     "text-embedding-ada-002" => 1536,
 };
 
-pub struct Client {
-    client: async_openai::Client<OpenAIConfig>,
+pub struct Client<C: async_openai::config::Config = OpenAIConfig> {
+    client: async_openai::Client<C>,
 }
 
 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
-        Self { client }
+    pub(crate) fn from_parts<C: async_openai::config::Config>(
+        client: async_openai::Client<C>,
+    ) -> Client<C> {
+        Client { client }
     }
 
     pub fn new(
@@ -67,6 +69,44 @@ impl Client {
     }
 }
 
+impl Client<AzureConfig> {
+    pub async fn new_azure(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> Result<Self> {
+        let config = match api_config {
+            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
+            Some(_) => api_bail!("unexpected config type, expected AzureOpenAiConfig"),
+            None => api_bail!("AzureOpenAiConfig is required for Azure OpenAI"),
+        };
+
+        let api_base =
+            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;
+
+        // Default to API version that supports structured outputs (json_schema).
+        let api_version = config
+            .api_version
+            .unwrap_or_else(|| "2024-08-01-preview".to_string());
+
+        let api_key = api_key
+            .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
+            .ok_or_else(|| anyhow::anyhow!(
+                "AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"
+            ))?;
+
+        let azure_config = AzureConfig::new()
+            .with_api_base(api_base)
+            .with_api_version(api_version)
+            .with_deployment_id(config.deployment_id)
+            .with_api_key(api_key);
+
+        Ok(Self {
+            client: OpenAIClient::with_config(azure_config),
+        })
+    }
+}
+
 pub(super) fn create_llm_generation_request(
     request: &super::LlmGenerateRequest,
 ) -> Result<CreateChatCompletionRequest> {
@@ -136,7 +176,10 @@ pub(super) fn create_llm_generation_request(
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl<C> LlmGenerationClient for Client<C>
+where
+    C: async_openai::config::Config + Send + Sync,
+{
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -175,7 +218,10 @@ impl LlmGenerationClient for Client {
 }
 
 #[async_trait]
-impl LlmEmbeddingClient for Client {
+impl<C> LlmEmbeddingClient for Client<C>
+where
+    C: async_openai::config::Config + Send + Sync,
+{
     async fn embed_text<'req>(
         &self,
         request: super::LlmEmbeddingRequest<'req>,
```
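
For reference, the builder chain inside `new_azure` is `async_openai`'s standard way to target Azure endpoints; a usage sketch with placeholder endpoint, deployment, and key values:

```rust
use async_openai::{Client, config::AzureConfig};

// Mirrors the config that `new_azure` assembles; the endpoint and
// deployment values here are placeholders, not real resources.
fn build_azure_client() -> anyhow::Result<Client<AzureConfig>> {
    let config = AzureConfig::new()
        .with_api_base("https://my-resource.openai.azure.com")
        .with_api_version("2024-08-01-preview")
        .with_deployment_id("my-gpt-4o-deployment")
        .with_api_key(std::env::var("AZURE_OPENAI_API_KEY")?);

    Ok(Client::with_config(config))
}
```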
