@@ -85,9 +85,13 @@ class RAIConfig:
8585    tracing : TracingConfig 
8686
8787
88- def  load_config () ->  RAIConfig :
89-     with  open ("config.toml" , "rb" ) as  f :
90-         config_dict  =  tomli .load (f )
88+ def  load_config (config_path : Optional [str ] =  None ) ->  RAIConfig :
89+     if  config_path  is  None :
90+         with  open ("config.toml" , "rb" ) as  f :
91+             config_dict  =  tomli .load (f )
92+     else :
93+         with  open (config_path , "rb" ) as  f :
94+             config_dict  =  tomli .load (f )
9195    return  RAIConfig (
9296        vendor = VendorConfig (** config_dict ["vendor" ]),
9397        aws = AWSConfig (** config_dict ["aws" ]),
@@ -104,8 +108,9 @@ def load_config() -> RAIConfig:
104108def  get_llm_model_config_and_vendor (
105109    model_type : Literal ["simple_model" , "complex_model" ],
106110    vendor : Optional [str ] =  None ,
111+     config_path : Optional [str ] =  None ,
107112) ->  Tuple [str , str ]:
108-     config  =  load_config ()
113+     config  =  load_config (config_path )
109114    if  vendor  is  None :
110115        if  model_type  ==  "simple_model" :
111116            vendor  =  config .vendor .simple_model 
@@ -119,9 +124,12 @@ def get_llm_model_config_and_vendor(
119124def  get_llm_model (
120125    model_type : Literal ["simple_model" , "complex_model" ],
121126    vendor : Optional [str ] =  None ,
127+     config_path : Optional [str ] =  None ,
122128    ** kwargs ,
123129):
124-     model_config , vendor  =  get_llm_model_config_and_vendor (model_type , vendor )
130+     model_config , vendor  =  get_llm_model_config_and_vendor (
131+         model_type , vendor , config_path 
132+     )
125133    model  =  getattr (model_config , model_type )
126134    logger .info (f"Initializing { model_type } { vendor } { model }  )
127135    if  vendor  ==  "openai" :
@@ -149,8 +157,8 @@ def get_llm_model(
149157        raise  ValueError (f"Unknown LLM vendor: { vendor }  )
150158
151159
152- def  get_embeddings_model (vendor : str  =  None ):
153-     config  =  load_config ()
160+ def  get_embeddings_model (vendor : str  =  None ,  config_path :  Optional [ str ]  =   None ):
161+     config  =  load_config (config_path )
154162    if  vendor  is  None :
155163        vendor  =  config .vendor .embeddings_model 
156164
0 commit comments