Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod usage;
pub use model_provider_info::ApiKeyProvider;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
mod models;
Expand Down
43 changes: 41 additions & 2 deletions codex-rs/core/src/model_provider_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Debug;
use std::sync::Arc;
use url::Url;

use crate::error::EnvVarError;

/// Trait for providing API keys dynamically
pub trait ApiKeyProvider: Send + Sync + Debug {
fn get(&self) -> String;
}

/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
Expand All @@ -30,7 +37,7 @@ pub enum WireApi {
}

/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
Expand All @@ -45,13 +52,37 @@ pub struct ModelProviderInfo {

/// Which wire protocol this provider expects.
pub wire_api: WireApi,

/// Optional API key provider for dynamic key retrieval
#[serde(skip)]
pub api_key_provider: Option<Arc<dyn ApiKeyProvider>>,
}

impl PartialEq for ModelProviderInfo {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
&& self.base_url == other.base_url
&& self.env_key == other.env_key
&& self.env_key_instructions == other.env_key_instructions
&& self.wire_api == other.wire_api
// Skip api_key_provider in comparison since trait objects can't be easily compared
}
}

impl ModelProviderInfo {
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
/// cannot be found, returns an error. If `api_key_provider` is set, uses that instead.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
// Check if we have an API key provider first
if let Some(provider) = &self.api_key_provider {
let key = provider.get();
if !key.trim().is_empty() {
return Ok(Some(key));
}
}

// Fall back to environment variable approach
match &self.env_key {
Some(env_key) => std::env::var(env_key)
.and_then(|v| {
Expand Down Expand Up @@ -85,6 +116,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("OPENAI_API_KEY".into()),
env_key_instructions: Some("Create an API key (https://platform.openai.com) and export it as an environment variable.".into()),
wire_api: WireApi::Responses,
api_key_provider: None,
},
),
(
Expand All @@ -95,6 +127,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("OPENROUTER_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -105,6 +138,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("GEMINI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -115,6 +149,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -125,6 +160,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("MISTRAL_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -135,6 +171,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("DEEPSEEK_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -145,6 +182,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("XAI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
(
Expand All @@ -155,6 +193,7 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
env_key: Some("GROQ_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
api_key_provider: None,
},
),
]
Expand Down
Loading