1515import logging
1616import os
1717from dataclasses import dataclass
18- from typing import List , Literal
18+ from typing import List , Literal , cast
1919
2020import coloredlogs
2121import tomli
2222from langchain_core .callbacks .base import BaseCallbackHandler
23+ from pydantic import SecretStr
2324
# Module-level logger; coloredlogs (imported above) is presumably installed on it
# elsewhere in the file — TODO confirm.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -50,6 +51,12 @@ class OllamaConfig(ModelConfig):
5051 base_url : str
5152
5253
@dataclass
class OpenAIConfig(ModelConfig):
    """Configuration for the OpenAI vendor, loaded from the ``[openai]`` TOML table.

    Extends the shared ``ModelConfig`` base with connection credentials.
    """

    # Endpoint for the OpenAI-compatible API (allows pointing at proxies or
    # self-hosted OpenAI-compatible servers).
    base_url: str
    # API key from config; an empty string means "fall back to the
    # OPENAI_API_KEY environment variable" (see get_llm_model below).
    api_key: str
58+
59+
5360@dataclass
5461class LangfuseConfig :
5562 use_langfuse : bool
@@ -72,7 +79,7 @@ class TracingConfig:
class RAIConfig:
    """Top-level application configuration aggregating all per-vendor sections.

    Built by ``load_config()`` from the parsed TOML dictionary; each attribute
    mirrors one table of the config file.
    """

    vendor: VendorConfig      # which vendor to use for each model role
    aws: AWSConfig            # AWS Bedrock settings (e.g. region)
    openai: OpenAIConfig      # OpenAI settings (base_url, api_key)
    ollama: OllamaConfig      # Ollama settings (base_url)
    tracing: TracingConfig    # Langfuse/LangSmith tracing settings
7885
@@ -83,7 +90,7 @@ def load_config() -> RAIConfig:
8390 return RAIConfig (
8491 vendor = VendorConfig (** config_dict ["vendor" ]),
8592 aws = AWSConfig (** config_dict ["aws" ]),
86- openai = ModelConfig (** config_dict ["openai" ]),
93+ openai = OpenAIConfig (** config_dict ["openai" ]),
8794 ollama = OllamaConfig (** config_dict ["ollama" ]),
8895 tracing = TracingConfig (
8996 project = config_dict ["tracing" ]["project" ],
@@ -110,17 +117,31 @@ def get_llm_model(
    # Dispatch on the configured vendor and build the matching LangChain chat
    # model. Vendor SDK imports are deliberately local so that only the
    # selected vendor's package needs to be installed.
    if vendor == "openai":
        from langchain_openai import ChatOpenAI

        # model_config is typed as the base ModelConfig upstream; narrow it for
        # the type checker — load_config() constructs OpenAIConfig for this slot.
        model_config = cast(OpenAIConfig, model_config)
        # Config-file key wins; an empty string falls back to the environment.
        api_key = (
            model_config.api_key
            if model_config.api_key != ""
            else os.getenv("OPENAI_API_KEY", None)
        )
        if api_key is None:
            raise ValueError("OPENAI_API_KEY is not set")

        # ChatOpenAI expects the key wrapped in pydantic's SecretStr.
        return ChatOpenAI(
            model=model, base_url=model_config.base_url, api_key=SecretStr(api_key)
        )
    elif vendor == "aws":
        from langchain_aws import ChatBedrock

        model_config = cast(AWSConfig, model_config)

        # Bedrock resolves credentials via the standard AWS chain; only the
        # region comes from our config.
        return ChatBedrock(
            model_id=model,
            region_name=model_config.region_name,
        )
    elif vendor == "ollama":
        from langchain_ollama import ChatOllama

        model_config = cast(OllamaConfig, model_config)
        return ChatOllama(model=model, base_url=model_config.base_url)
    else:
        # vendor is presumably a Literal upstream, but guard against bad config.
        raise ValueError(f"Unknown LLM vendor: {vendor}")
0 commit comments