1515import  logging 
1616import  os 
1717from  dataclasses  import  dataclass 
18- from  typing  import  List , Literal 
18+ from  typing  import  List , Literal ,  cast 
1919
2020import  coloredlogs 
2121import  tomli 
2222from  langchain_core .callbacks .base  import  BaseCallbackHandler 
23+ from  pydantic  import  SecretStr 
2324
2425logger  =  logging .getLogger (__name__ )
2526logger .setLevel (logging .INFO )
@@ -50,6 +51,12 @@ class OllamaConfig(ModelConfig):
5051    base_url : str 
5152
5253
54+ @dataclass  
55+ class  OpenAIConfig (ModelConfig ):
56+     base_url : str 
57+     api_key : str 
58+ 
59+ 
5360@dataclass  
5461class  LangfuseConfig :
5562    use_langfuse : bool 
@@ -72,7 +79,7 @@ class TracingConfig:
7279class  RAIConfig :
7380    vendor : VendorConfig 
7481    aws : AWSConfig 
75-     openai : ModelConfig 
82+     openai : OpenAIConfig 
7683    ollama : OllamaConfig 
7784    tracing : TracingConfig 
7885
@@ -83,7 +90,7 @@ def load_config() -> RAIConfig:
8390    return  RAIConfig (
8491        vendor = VendorConfig (** config_dict ["vendor" ]),
8592        aws = AWSConfig (** config_dict ["aws" ]),
86-         openai = ModelConfig (** config_dict ["openai" ]),
93+         openai = OpenAIConfig (** config_dict ["openai" ]),
8794        ollama = OllamaConfig (** config_dict ["ollama" ]),
8895        tracing = TracingConfig (
8996            project = config_dict ["tracing" ]["project" ],
@@ -110,17 +117,31 @@ def get_llm_model(
110117    if  vendor  ==  "openai" :
111118        from  langchain_openai  import  ChatOpenAI 
112119
113-         return  ChatOpenAI (model = model )
120+         model_config  =  cast (OpenAIConfig , model_config )
121+         api_key  =  (
122+             model_config .api_key 
123+             if  model_config .api_key  !=  "" 
124+             else  os .getenv ("OPENAI_API_KEY" , None )
125+         )
126+         if  api_key  is  None :
127+             raise  ValueError ("OPENAI_API_KEY is not set" )
128+ 
129+         return  ChatOpenAI (
130+             model = model , base_url = model_config .base_url , api_key = SecretStr (api_key )
131+         )
114132    elif  vendor  ==  "aws" :
115133        from  langchain_aws  import  ChatBedrock 
116134
135+         model_config  =  cast (AWSConfig , model_config )
136+ 
117137        return  ChatBedrock (
118138            model_id = model ,
119139            region_name = model_config .region_name ,
120140        )
121141    elif  vendor  ==  "ollama" :
122142        from  langchain_ollama  import  ChatOllama 
123143
144+         model_config  =  cast (OllamaConfig , model_config )
124145        return  ChatOllama (model = model , base_url = model_config .base_url )
125146    else :
126147        raise  ValueError (f"Unknown LLM vendor: { vendor }  )
0 commit comments