1 change: 1 addition & 0 deletions python/cocoindex/functions.py
@@ -67,6 +67,7 @@ class EmbedText(op.FunctionSpec):
     output_dimension: int | None = None
     task_type: str | None = None
     api_config: llm.VertexAiConfig | None = None
+    api_key: str | None = None
 
 
 class ExtractByLlm(op.FunctionSpec):
1 change: 1 addition & 0 deletions python/cocoindex/functions/_engine_builtin_specs.py
@@ -56,6 +56,7 @@ class EmbedText(op.FunctionSpec):
     output_dimension: int | None = None
     task_type: str | None = None
     api_config: llm.VertexAiConfig | None = None
+    api_key: str | None = None
 
 
 class ExtractByLlm(op.FunctionSpec):
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -43,4 +43,5 @@ class LlmSpec:
     api_type: LlmApiType
     model: str
     address: str | None = None
+    api_key: str | None = None
     api_config: VertexAiConfig | OpenAiConfig | None = None
13 changes: 9 additions & 4 deletions src/llm/anthropic.rs
@@ -14,14 +14,19 @@ pub struct Client {
 }
 
 impl Client {
-    pub async fn new(address: Option<String>) -> Result<Self> {
+    pub async fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Anthropic doesn't support custom API address");
         }
-        let api_key = match std::env::var("ANTHROPIC_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("ANTHROPIC_API_KEY").map_err(|_| {
+                anyhow::anyhow!("ANTHROPIC_API_KEY environment variable must be set")
+            })?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
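The pattern above — prefer an explicitly passed key, fall back to an environment variable, and fail loudly if neither is present — recurs in the Anthropic, Gemini AI Studio, and Voyage clients in this PR. A minimal standalone sketch of that fallback; the helper name `resolve_required_api_key` is illustrative, not part of this change:

```rust
use anyhow::{anyhow, Result};

// Illustrative helper mirroring the diff above: an explicit key wins;
// otherwise the named environment variable is consulted; a key missing
// from both places is a hard error.
fn resolve_required_api_key(explicit: Option<String>, env_var: &str) -> Result<String> {
    match explicit {
        Some(key) => Ok(key),
        None => std::env::var(env_var)
            .map_err(|_| anyhow!("{env_var} environment variable must be set")),
    }
}
```

Called as `resolve_required_api_key(api_key, "ANTHROPIC_API_KEY")?`, this reproduces the behavior of the `if let` block in the hunk above.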
13 changes: 9 additions & 4 deletions src/llm/gemini.rs
@@ -34,14 +34,18 @@ pub struct AiStudioClient {
 }
 
 impl AiStudioClient {
-    pub fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Gemini doesn't support custom API address");
         }
-        let api_key = match std::env::var("GEMINI_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("GEMINI_API_KEY")
+                .map_err(|_| anyhow::anyhow!("GEMINI_API_KEY environment variable must be set"))?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
@@ -271,6 +275,7 @@ static SHARED_RETRY_THROTTLER: LazyLock<SharedRetryThrottler> =
 impl VertexAiClient {
     pub async fn new(
         address: Option<String>,
+        _api_key: Option<String>,
         api_config: Option<super::LlmApiConfig>,
     ) -> Result<Self> {
         if address.is_some() {
9 changes: 7 additions & 2 deletions src/llm/litellm.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_litellm(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_litellm(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "http://127.0.0.1:4000".to_string());
-        let api_key = std::env::var("LITELLM_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("LITELLM_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
48 changes: 26 additions & 22 deletions src/llm/mod.rs
@@ -44,6 +44,7 @@ pub struct LlmSpec {
     pub api_type: LlmApiType,
     pub address: Option<String>,
     pub model: String,
+    pub api_key: Option<String>,
     pub api_config: Option<LlmApiConfig>,
 }
 
@@ -117,58 +118,61 @@ mod voyage;
 pub async fn new_llm_generation_client(
     api_type: LlmApiType,
     address: Option<String>,
+    api_key: Option<String>,
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match api_type {
         LlmApiType::Ollama => {
             Box::new(ollama::Client::new(address).await?) as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::Gemini => {
-            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmGenerationClient>
-        }
-        LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
             as Box<dyn LlmGenerationClient>,
-        LlmApiType::Anthropic => {
-            Box::new(anthropic::Client::new(address).await?) as Box<dyn LlmGenerationClient>
+        LlmApiType::Gemini => {
+            Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::LiteLlm => {
-            Box::new(litellm::Client::new_litellm(address).await?) as Box<dyn LlmGenerationClient>
+        LlmApiType::VertexAi => {
+            Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
+                as Box<dyn LlmGenerationClient>
         }
-        LlmApiType::OpenRouter => Box::new(openrouter::Client::new_openrouter(address).await?)
+        LlmApiType::Anthropic => Box::new(anthropic::Client::new(address, api_key).await?)
             as Box<dyn LlmGenerationClient>,
+        LlmApiType::LiteLlm => Box::new(litellm::Client::new_litellm(address, api_key).await?)
+            as Box<dyn LlmGenerationClient>,
+        LlmApiType::OpenRouter => {
+            Box::new(openrouter::Client::new_openrouter(address, api_key).await?)
+                as Box<dyn LlmGenerationClient>
+        }
         LlmApiType::Voyage => {
             api_bail!("Voyage is not supported for generation")
         }
-        LlmApiType::Vllm => {
-            Box::new(vllm::Client::new_vllm(address).await?) as Box<dyn LlmGenerationClient>
-        }
+        LlmApiType::Vllm => Box::new(vllm::Client::new_vllm(address, api_key).await?)
+            as Box<dyn LlmGenerationClient>,
     };
     Ok(client)
 }
 
 pub async fn new_llm_embedding_client(
     api_type: LlmApiType,
     address: Option<String>,
+    api_key: Option<String>,
     api_config: Option<LlmApiConfig>,
 ) -> Result<Box<dyn LlmEmbeddingClient>> {
     let client = match api_type {
         LlmApiType::Ollama => {
             Box::new(ollama::Client::new(address).await?) as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::Gemini => {
-            Box::new(gemini::AiStudioClient::new(address)?) as Box<dyn LlmEmbeddingClient>
+            Box::new(gemini::AiStudioClient::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
         }
-        LlmApiType::OpenAi => {
-            Box::new(openai::Client::new(address, api_config)?) as Box<dyn LlmEmbeddingClient>
-        }
+        LlmApiType::OpenAi => Box::new(openai::Client::new(address, api_key, api_config)?)
+            as Box<dyn LlmEmbeddingClient>,
         LlmApiType::Voyage => {
-            Box::new(voyage::Client::new(address)?) as Box<dyn LlmEmbeddingClient>
+            Box::new(voyage::Client::new(address, api_key)?) as Box<dyn LlmEmbeddingClient>
         }
-        LlmApiType::VertexAi => Box::new(gemini::VertexAiClient::new(address, api_config).await?)
-            as Box<dyn LlmEmbeddingClient>,
+        LlmApiType::VertexAi => {
+            Box::new(gemini::VertexAiClient::new(address, api_key, api_config).await?)
+                as Box<dyn LlmEmbeddingClient>
+        }
         LlmApiType::OpenRouter | LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic => {
             api_bail!("Embedding is not supported for API type {:?}", api_type)
         }
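Both factory functions now thread the optional key through to every provider client, so a caller can supply a credential directly instead of relying on the environment. A hedged usage sketch from inside the crate — the import path follows the module layout implied by src/llm/mod.rs, and the key string is a placeholder, not a real credential:

```rust
use crate::llm::{new_llm_generation_client, LlmApiType};

// Sketch: construct a generation client with an explicit API key.
// Passing None for api_key keeps the old behavior (environment lookup).
async fn example() -> anyhow::Result<()> {
    let _client = new_llm_generation_client(
        LlmApiType::OpenAi,
        None,                           // address: use the provider's default endpoint
        Some("sk-example".to_string()), // api_key: takes precedence over OPENAI_API_KEY
        None,                           // api_config
    )
    .await?;
    Ok(())
}
```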
19 changes: 13 additions & 6 deletions src/llm/openai.rs
@@ -32,7 +32,11 @@ impl Client {
         Self { client }
     }
 
-    pub fn new(address: Option<String>, api_config: Option<super::LlmApiConfig>) -> Result<Self> {
+    pub fn new(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> Result<Self> {
         let config = match api_config {
             Some(super::LlmApiConfig::OpenAi(config)) => config,
             Some(_) => api_bail!("unexpected config type, expected OpenAiConfig"),
@@ -49,13 +53,16 @@ impl Client {
         if let Some(project_id) = config.project_id {
             openai_config = openai_config.with_project_id(project_id);
         }
-
-        // Verify API key is set
-        if std::env::var("OPENAI_API_KEY").is_err() {
-            api_bail!("OPENAI_API_KEY environment variable must be set");
-        }
+        if let Some(key) = api_key {
+            openai_config = openai_config.with_api_key(key);
+        } else {
+            // Verify API key is set in environment if not provided in config
+            if std::env::var("OPENAI_API_KEY").is_err() {
+                api_bail!("OPENAI_API_KEY environment variable must be set");
+            }
+        }
 
         Ok(Self {
             // OpenAI client will use OPENAI_API_KEY and OPENAI_API_BASE env variables by default
             client: OpenAIClient::with_config(openai_config),
         })
     }
9 changes: 7 additions & 2 deletions src/llm/openrouter.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_openrouter(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_openrouter(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "https://openrouter.ai/api/v1".to_string());
-        let api_key = std::env::var("OPENROUTER_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("OPENROUTER_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
9 changes: 7 additions & 2 deletions src/llm/vllm.rs
@@ -4,9 +4,14 @@ use async_openai::config::OpenAIConfig;
 pub use super::openai::Client;
 
 impl Client {
-    pub async fn new_vllm(address: Option<String>) -> anyhow::Result<Self> {
+    pub async fn new_vllm(
+        address: Option<String>,
+        api_key: Option<String>,
+    ) -> anyhow::Result<Self> {
         let address = address.unwrap_or_else(|| "http://127.0.0.1:8000/v1".to_string());
-        let api_key = std::env::var("VLLM_API_KEY").ok();
+
+        let api_key = api_key.or_else(|| std::env::var("VLLM_API_KEY").ok());
+
         let mut config = OpenAIConfig::new().with_api_base(address);
         if let Some(api_key) = api_key {
             config = config.with_api_key(api_key);
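Unlike the required-key providers, LiteLLM, OpenRouter, and vLLM keep the key optional: these OpenAI-compatible endpoints are often self-hosted and may need no credential at all, so a missing key is not an error. The shared fallback reduces to a single combinator — a sketch with an illustrative helper name:

```rust
// Illustrative helper: explicit key first, then the environment; both may
// legitimately be absent for keyless local servers, so no error is raised.
fn optional_api_key(explicit: Option<String>, env_var: &str) -> Option<String> {
    explicit.or_else(|| std::env::var(env_var).ok())
}
```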
12 changes: 8 additions & 4 deletions src/llm/voyage.rs
@@ -33,14 +33,18 @@ pub struct Client {
 }
 
 impl Client {
-    pub fn new(address: Option<String>) -> Result<Self> {
+    pub fn new(address: Option<String>, api_key: Option<String>) -> Result<Self> {
         if address.is_some() {
             api_bail!("Voyage AI doesn't support custom API address");
         }
-        let api_key = match std::env::var("VOYAGE_API_KEY") {
-            Ok(val) => val,
-            Err(_) => api_bail!("VOYAGE_API_KEY environment variable must be set"),
+
+        let api_key = if let Some(key) = api_key {
+            key
+        } else {
+            std::env::var("VOYAGE_API_KEY")
+                .map_err(|_| anyhow::anyhow!("VOYAGE_API_KEY environment variable must be set"))?
         };
+
         Ok(Self {
             api_key,
             client: reqwest::Client::new(),
13 changes: 10 additions & 3 deletions src/ops/functions/embed_text.rs
@@ -13,6 +13,7 @@ struct Spec {
     api_config: Option<LlmApiConfig>,
     output_dimension: Option<u32>,
     task_type: Option<String>,
+    api_key: Option<String>,
 }
 
 struct Args {
@@ -91,9 +92,14 @@ impl SimpleFunctionFactoryBase for Factory {
             .next_arg("text")?
             .expect_type(&ValueType::Basic(BasicValueType::Str))?
             .required()?;
-        let client =
-            new_llm_embedding_client(spec.api_type, spec.address.clone(), spec.api_config.clone())
-                .await?;
+
+        let client = new_llm_embedding_client(
+            spec.api_type,
+            spec.address.clone(),
+            spec.api_key.clone(),
+            spec.api_config.clone(),
+        )
+        .await?;
         let output_dimension = match spec.output_dimension {
             Some(output_dimension) => output_dimension,
             None => {
@@ -144,6 +150,7 @@ mod tests {
            api_config: None,
            output_dimension: None,
            task_type: None,
+           api_key: None,
        };

        let factory = Arc::new(Factory);
3 changes: 3 additions & 0 deletions src/ops/functions/extract_by_llm.rs
@@ -55,6 +55,7 @@ impl Executor {
         let client = new_llm_generation_client(
             spec.llm_spec.api_type,
             spec.llm_spec.address,
+            spec.llm_spec.api_key,
             spec.llm_spec.api_config,
         )
         .await?;
@@ -204,6 +205,7 @@ mod tests {
                 api_type: crate::llm::LlmApiType::OpenAi,
                 model: "gpt-4o".to_string(),
                 address: None,
+                api_key: None,
                 api_config: None,
             },
             output_type: output_type_spec,
@@ -274,6 +276,7 @@ mod tests {
                 api_type: crate::llm::LlmApiType::OpenAi,
                 model: "gpt-4o".to_string(),
                 address: None,
+                api_key: None,
                 api_config: None,
             },
             output_type: make_output_type(BasicValueType::Str),