@@ -4,7 +4,7 @@ use base64::prelude::*;
44use super :: { LlmEmbeddingClient , LlmGenerationClient , detect_image_mime_type} ;
55use async_openai:: {
66 Client as OpenAIClient ,
7- config:: OpenAIConfig ,
7+ config:: { AzureConfig , OpenAIConfig } ,
88 types:: {
99 ChatCompletionRequestMessage , ChatCompletionRequestMessageContentPartImage ,
1010 ChatCompletionRequestMessageContentPartText , ChatCompletionRequestSystemMessage ,
@@ -22,13 +22,15 @@ static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
2222 "text-embedding-ada-002" => 1536 ,
2323} ;
2424
25- pub struct Client {
26- client : async_openai:: Client < OpenAIConfig > ,
25+ pub struct Client < C : async_openai :: config :: Config = OpenAIConfig > {
26+ client : async_openai:: Client < C > ,
2727}
2828
2929impl Client {
30- pub ( crate ) fn from_parts ( client : async_openai:: Client < OpenAIConfig > ) -> Self {
31- Self { client }
30+ pub ( crate ) fn from_parts < C : async_openai:: config:: Config > (
31+ client : async_openai:: Client < C > ,
32+ ) -> Client < C > {
33+ Client { client }
3234 }
3335
3436 pub fn new (
@@ -67,6 +69,44 @@ impl Client {
6769 }
6870}
6971
72+ impl Client < AzureConfig > {
73+ pub async fn new_azure (
74+ address : Option < String > ,
75+ api_key : Option < String > ,
76+ api_config : Option < super :: LlmApiConfig > ,
77+ ) -> Result < Self > {
78+ let config = match api_config {
79+ Some ( super :: LlmApiConfig :: AzureOpenAi ( config) ) => config,
80+ Some ( _) => api_bail ! ( "unexpected config type, expected AzureOpenAiConfig" ) ,
81+ None => api_bail ! ( "AzureOpenAiConfig is required for Azure OpenAI" ) ,
82+ } ;
83+
84+ let api_base =
85+ address. ok_or_else ( || anyhow:: anyhow!( "address is required for Azure OpenAI" ) ) ?;
86+
87+ // Default to API version that supports structured outputs (json_schema).
88+ let api_version = config
89+ . api_version
90+ . unwrap_or_else ( || "2024-08-01-preview" . to_string ( ) ) ;
91+
92+ let api_key = api_key
93+ . or_else ( || std:: env:: var ( "AZURE_OPENAI_API_KEY" ) . ok ( ) )
94+ . ok_or_else ( || anyhow:: anyhow!(
95+ "AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"
96+ ) ) ?;
97+
98+ let azure_config = AzureConfig :: new ( )
99+ . with_api_base ( api_base)
100+ . with_api_version ( api_version)
101+ . with_deployment_id ( config. deployment_id )
102+ . with_api_key ( api_key) ;
103+
104+ Ok ( Self {
105+ client : OpenAIClient :: with_config ( azure_config) ,
106+ } )
107+ }
108+ }
109+
70110pub ( super ) fn create_llm_generation_request (
71111 request : & super :: LlmGenerateRequest ,
72112) -> Result < CreateChatCompletionRequest > {
@@ -136,7 +176,10 @@ pub(super) fn create_llm_generation_request(
136176}
137177
138178#[ async_trait]
139- impl LlmGenerationClient for Client {
179+ impl < C > LlmGenerationClient for Client < C >
180+ where
181+ C : async_openai:: config:: Config + Send + Sync ,
182+ {
140183 async fn generate < ' req > (
141184 & self ,
142185 request : super :: LlmGenerateRequest < ' req > ,
@@ -175,7 +218,10 @@ impl LlmGenerationClient for Client {
175218}
176219
177220#[ async_trait]
178- impl LlmEmbeddingClient for Client {
221+ impl < C > LlmEmbeddingClient for Client < C >
222+ where
223+ C : async_openai:: config:: Config + Send + Sync ,
224+ {
179225 async fn embed_text < ' req > (
180226 & self ,
181227 request : super :: LlmEmbeddingRequest < ' req > ,
0 commit comments