Commit a5ae069

chore: update imports in debugging assistant
feat: add config_path param to model initialization functionality
1 parent: d81d7d3

File tree: 2 files changed (+17, -9 lines)

examples/debugging_assistant.py

Lines changed: 2 additions & 2 deletions
@@ -14,8 +14,8 @@
 
 import streamlit as st
 from rai import get_llm_model
-from rai.agents.conversational_agent import create_conversational_agent
-from rai.frontend.streamlit import run_streamlit_app
+from rai.agents import create_conversational_agent
+from rai.frontend import run_streamlit_app
 from rai.tools.ros2 import ROS2CLIToolkit
 
 
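
For orientation, here is a minimal sketch of how the renamed import paths fit together in a Streamlit example such as this one. The wiring below is an assumption for illustration; the actual system prompt, toolkit usage, and keyword arguments in debugging_assistant.py may differ:

from rai import get_llm_model
from rai.agents import create_conversational_agent
from rai.frontend import run_streamlit_app
from rai.tools.ros2 import ROS2CLIToolkit

# Hypothetical wiring; argument names and values are illustrative assumptions.
llm = get_llm_model(model_type="complex_model")
agent = create_conversational_agent(
    llm=llm,
    tools=ROS2CLIToolkit().get_tools(),
    system_prompt="You are a ROS 2 debugging assistant.",
)
run_streamlit_app(agent)

Only the import paths change in this commit; the public names themselves are untouched.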

src/rai_core/rai/initialization/model_initialization.py

Lines changed: 15 additions & 7 deletions
@@ -85,9 +85,13 @@ class RAIConfig:
     tracing: TracingConfig
 
 
-def load_config() -> RAIConfig:
-    with open("config.toml", "rb") as f:
-        config_dict = tomli.load(f)
+def load_config(config_path: Optional[str] = None) -> RAIConfig:
+    if config_path is None:
+        with open("config.toml", "rb") as f:
+            config_dict = tomli.load(f)
+    else:
+        with open(config_path, "rb") as f:
+            config_dict = tomli.load(f)
     return RAIConfig(
         vendor=VendorConfig(**config_dict["vendor"]),
         aws=AWSConfig(**config_dict["aws"]),
@@ -104,8 +108,9 @@ def load_config() -> RAIConfig:
 def get_llm_model_config_and_vendor(
     model_type: Literal["simple_model", "complex_model"],
     vendor: Optional[str] = None,
+    config_path: Optional[str] = None,
 ) -> Tuple[str, str]:
-    config = load_config()
+    config = load_config(config_path)
     if vendor is None:
         if model_type == "simple_model":
             vendor = config.vendor.simple_model
@@ -119,9 +124,12 @@ def get_llm_model_config_and_vendor(
 def get_llm_model(
     model_type: Literal["simple_model", "complex_model"],
     vendor: Optional[str] = None,
+    config_path: Optional[str] = None,
     **kwargs,
 ):
-    model_config, vendor = get_llm_model_config_and_vendor(model_type, vendor)
+    model_config, vendor = get_llm_model_config_and_vendor(
+        model_type, vendor, config_path
+    )
     model = getattr(model_config, model_type)
     logger.info(f"Initializing {model_type}: Vendor: {vendor}, Model: {model}")
     if vendor == "openai":
@@ -149,8 +157,8 @@ def get_llm_model(
         raise ValueError(f"Unknown LLM vendor: {vendor}")
 
 
-def get_embeddings_model(vendor: str = None):
-    config = load_config()
+def get_embeddings_model(vendor: str = None, config_path: Optional[str] = None):
+    config = load_config(config_path)
     if vendor is None:
         vendor = config.vendor.embeddings_model
 
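
Taken together, the new optional config_path parameter threads from get_llm_model, get_llm_model_config_and_vendor, and get_embeddings_model down to load_config, so callers can point model initialization at a configuration file other than the config.toml in the current working directory. A minimal usage sketch based on the signatures above; the file path is a hypothetical example:

from rai import get_llm_model
from rai.initialization.model_initialization import get_embeddings_model

# Default behaviour is unchanged: with no config_path, load_config() still
# reads "config.toml" from the current working directory.
llm = get_llm_model(model_type="simple_model")

# New in this commit: an explicit config file can be supplied.
llm = get_llm_model(
    model_type="complex_model",
    config_path="configs/my_robot.toml",  # hypothetical path
)
embeddings = get_embeddings_model(config_path="configs/my_robot.toml")

Defaulting config_path to None keeps existing call sites working, since load_config falls back to the hard-coded "config.toml" when no path is given.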
