Skip to content

Commit 2317a65

Browse files
committed
feat: allow passing base_url and api_key in openai section of config.toml
1 parent 53b457f commit 2317a65

File tree

2 files changed

+27
-4
lines changed

config.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ region_name = "us-east-1"
1313
simple_model = "gpt-4o-mini"
1414
complex_model = "gpt-4o-2024-08-06"
1515
embeddings_model = "text-embedding-ada-002"
16+
base_url = "https://api.openai.com/v1/" # for openai compatible apis
17+
api_key = "" # for openai compatible apis, overrides the environment variable OPENAI_API_KEY
1618

1719
[ollama]
1820
simple_model = "llama3.2"

src/rai/rai/utils/model_initialization.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
import logging
1616
import os
1717
from dataclasses import dataclass
18-
from typing import List, Literal
18+
from typing import List, Literal, cast
1919

2020
import coloredlogs
2121
import tomli
2222
from langchain_core.callbacks.base import BaseCallbackHandler
23+
from pydantic import SecretStr
2324

2425
logger = logging.getLogger(__name__)
2526
logger.setLevel(logging.INFO)
@@ -50,6 +51,12 @@ class OllamaConfig(ModelConfig):
5051
base_url: str
5152

5253

54+
@dataclass
55+
class OpenAIConfig(ModelConfig):
56+
base_url: str
57+
api_key: str
58+
59+
5360
@dataclass
5461
class LangfuseConfig:
5562
use_langfuse: bool
@@ -72,7 +79,7 @@ class TracingConfig:
7279
class RAIConfig:
7380
vendor: VendorConfig
7481
aws: AWSConfig
75-
openai: ModelConfig
82+
openai: OpenAIConfig
7683
ollama: OllamaConfig
7784
tracing: TracingConfig
7885

@@ -83,7 +90,7 @@ def load_config() -> RAIConfig:
8390
return RAIConfig(
8491
vendor=VendorConfig(**config_dict["vendor"]),
8592
aws=AWSConfig(**config_dict["aws"]),
86-
openai=ModelConfig(**config_dict["openai"]),
93+
openai=OpenAIConfig(**config_dict["openai"]),
8794
ollama=OllamaConfig(**config_dict["ollama"]),
8895
tracing=TracingConfig(
8996
project=config_dict["tracing"]["project"],
@@ -110,17 +117,31 @@ def get_llm_model(
110117
if vendor == "openai":
111118
from langchain_openai import ChatOpenAI
112119

113-
return ChatOpenAI(model=model)
120+
model_config = cast(OpenAIConfig, model_config)
121+
api_key = (
122+
model_config.api_key
123+
if model_config.api_key != ""
124+
else os.getenv("OPENAI_API_KEY", None)
125+
)
126+
if api_key is None:
127+
raise ValueError("OPENAI_API_KEY is not set")
128+
129+
return ChatOpenAI(
130+
model=model, base_url=model_config.base_url, api_key=SecretStr(api_key)
131+
)
114132
elif vendor == "aws":
115133
from langchain_aws import ChatBedrock
116134

135+
model_config = cast(AWSConfig, model_config)
136+
117137
return ChatBedrock(
118138
model_id=model,
119139
region_name=model_config.region_name,
120140
)
121141
elif vendor == "ollama":
122142
from langchain_ollama import ChatOllama
123143

144+
model_config = cast(OllamaConfig, model_config)
124145
return ChatOllama(model=model, base_url=model_config.base_url)
125146
else:
126147
raise ValueError(f"Unknown LLM vendor: {vendor}")

0 commit comments

Comments (0)