Skip to content

Commit 32b153e

Browse files
committed
refactor: simplify tests, feat: test asr
1 parent d6dfd03 commit 32b153e

File tree

1 file changed

+132
-105
lines changed

1 file changed

+132
-105
lines changed

src/rai/rai/utils/configurator.py

Lines changed: 132 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from typing import Dict, List
1717

18+
import numpy as np
1819
import sounddevice as sd
1920
import streamlit as st
2021
import tomli
@@ -565,117 +566,143 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in
565566
st.code(toml_string, language="toml")
566567

567568
if st.button("Test Configuration"):
568-
success = True
569569
progress = st.progress(0.0)
570-
571-
vendor = st.session_state.config["vendor"]
572-
simple_model_vendor_name = st.session_state.config["vendor"]["simple_model"]
573-
complex_model_vendor_name = st.session_state.config["vendor"]["complex_model"]
574-
embeddings_model_vendor_name = st.session_state.config["vendor"][
575-
"embeddings_model"
576-
]
577-
578-
# create simple model
579-
progress.progress(0.1)
580-
try:
581-
if simple_model_vendor_name == "openai":
582-
simple_model = ChatOpenAI(
583-
model=st.session_state.config["openai"]["simple_model"]
584-
)
585-
elif simple_model_vendor_name == "aws":
586-
simple_model = ChatBedrock(
587-
model_id=st.session_state.config["aws"]["simple_model"]
588-
)
589-
elif simple_model_vendor_name == "ollama":
590-
simple_model = ChatOllama(
591-
model=st.session_state.config["ollama"]["simple_model"],
592-
base_url=st.session_state.config["ollama"]["base_url"],
593-
)
594-
except Exception as e:
595-
success = False
596-
st.error(f"Failed to initialize simple model: {e}")
597-
598-
# create complex model
599-
progress.progress(0.2)
600-
try:
601-
if complex_model_vendor_name == "openai":
602-
complex_model = ChatOpenAI(
603-
model=st.session_state.config["openai"]["complex_model"]
604-
)
605-
elif complex_model_vendor_name == "aws":
606-
complex_model = ChatBedrock(
607-
model_id=st.session_state.config["aws"]["complex_model"]
608-
)
609-
elif complex_model_vendor_name == "ollama":
610-
complex_model = ChatOllama(
611-
model=st.session_state.config["ollama"]["complex_model"],
570+
test_results = {}
571+
572+
def create_chat_model(model_type: str):
573+
vendor_name = st.session_state.config["vendor"][f"{model_type}_model"]
574+
model_name = st.session_state.config[vendor_name][f"{model_type}_model"]
575+
576+
if vendor_name == "openai":
577+
return ChatOpenAI(model=model_name)
578+
elif vendor_name == "aws":
579+
return ChatBedrock(model_id=model_name)
580+
elif vendor_name == "ollama":
581+
return ChatOllama(
582+
model=model_name,
612583
base_url=st.session_state.config["ollama"]["base_url"],
613584
)
614-
except Exception as e:
615-
success = False
616-
st.error(f"Failed to initialize complex model: {e}")
617-
618-
# create embeddings model
619-
progress.progress(0.3)
620-
try:
621-
if embeddings_model_vendor_name == "openai":
622-
embeddings_model = OpenAIEmbeddings(
623-
model=st.session_state.config["openai"]["embeddings_model"]
624-
)
625-
elif embeddings_model_vendor_name == "aws":
626-
embeddings_model = BedrockEmbeddings(
627-
model_id=st.session_state.config["aws"]["embeddings_model"]
628-
)
629-
elif embeddings_model_vendor_name == "ollama":
630-
embeddings_model = OllamaEmbeddings(
631-
model=st.session_state.config["ollama"]["embeddings_model"],
632-
base_url=st.session_state.config["ollama"]["base_url"],
633-
)
634-
except Exception as e:
635-
success = False
636-
st.error(f"Failed to initialize embeddings model: {e}")
637-
638-
progress.progress(0.4)
639-
use_langfuse = st.session_state.config["tracing"]["langfuse"]["use_langfuse"]
640-
if use_langfuse:
641-
if not os.getenv("LANGFUSE_SECRET_KEY", "") or not os.getenv(
642-
"LANGFUSE_PUBLIC_KEY", ""
643-
):
644-
success = False
645-
st.error(
646-
"Langfuse is enabled but LANGFUSE_SECRET_KEY or LANGFUSE_PUBLIC_KEY is not set"
585+
raise ValueError(f"Unknown vendor: {vendor_name}")
586+
587+
def test_chat_model(model_type: str) -> bool:
588+
try:
589+
model = create_chat_model(model_type)
590+
answer = model.invoke("Say hello!")
591+
return answer.content is not None
592+
except Exception as e:
593+
st.error(f"{model_type.title()} model error: {e}")
594+
return False
595+
596+
def test_simple_model() -> bool:
597+
return test_chat_model("simple")
598+
599+
def test_complex_model() -> bool:
600+
return test_chat_model("complex")
601+
602+
def test_embeddings_model():
603+
try:
604+
embeddings_model_vendor_name = st.session_state.config["vendor"][
605+
"embeddings_model"
606+
]
607+
if embeddings_model_vendor_name == "openai":
608+
embeddings_model = OpenAIEmbeddings(
609+
model=st.session_state.config["openai"]["embeddings_model"]
610+
)
611+
elif embeddings_model_vendor_name == "aws":
612+
embeddings_model = BedrockEmbeddings(
613+
model_id=st.session_state.config["aws"]["embeddings_model"]
614+
)
615+
elif embeddings_model_vendor_name == "ollama":
616+
embeddings_model = OllamaEmbeddings(
617+
model=st.session_state.config["ollama"]["embeddings_model"],
618+
base_url=st.session_state.config["ollama"]["base_url"],
619+
)
620+
embeddings_answer = embeddings_model.embed_query("Say hello!")
621+
return embeddings_answer is not None
622+
except Exception as e:
623+
st.error(f"Embeddings model error: {e}")
624+
return False
625+
626+
def test_langfuse():
627+
use_langfuse = st.session_state.config["tracing"]["langfuse"][
628+
"use_langfuse"
629+
]
630+
if not use_langfuse:
631+
return True
632+
return bool(os.getenv("LANGFUSE_SECRET_KEY")) and bool(
633+
os.getenv("LANGFUSE_PUBLIC_KEY")
634+
)
635+
636+
def test_langsmith():
637+
use_langsmith = st.session_state.config["tracing"]["langsmith"][
638+
"use_langsmith"
639+
]
640+
if not use_langsmith:
641+
return True
642+
return bool(os.getenv("LANGCHAIN_API_KEY"))
643+
644+
def test_recording_device(index: int, sample_rate: int):
645+
try:
646+
recording = sd.rec(
647+
device=index,
648+
frames=sample_rate,
649+
samplerate=sample_rate,
650+
channels=1,
651+
dtype="int16",
647652
)
653+
sd.wait()
654+
if np.sum(np.abs(recording)) == 0:
655+
return False
656+
return True
657+
except Exception as e:
658+
st.error(f"Recording device error: {e}")
659+
return False
660+
661+
# Run tests
662+
progress.progress(0.2, "Testing simple model...")
663+
test_results["Simple Model"] = test_simple_model()
664+
665+
progress.progress(0.4, "Testing complex model...")
666+
test_results["Complex Model"] = test_complex_model()
667+
668+
progress.progress(0.6, "Testing embeddings model...")
669+
test_results["Embeddings Model"] = test_embeddings_model()
670+
671+
progress.progress(0.8, "Testing tracing...")
672+
test_results["Langfuse"] = test_langfuse()
673+
test_results["LangSmith"] = test_langsmith()
674+
675+
progress.progress(0.9, "Testing recording device...")
676+
devices = sd.query_devices()
677+
device_index = [device["name"] for device in devices].index(
678+
st.session_state.config["asr"]["recording_device_name"]
679+
)
680+
sample_rate = int(devices[device_index]["default_samplerate"])
681+
test_results["Recording Device"] = test_recording_device(
682+
device_index, sample_rate
683+
)
684+
685+
progress.progress(1.0)
686+
687+
# Display results in a table
688+
st.subheader("Test Results")
689+
690+
# Create a two-column table using streamlit columns
691+
col1, col2 = st.columns(2)
692+
with col1:
693+
st.markdown("**Component**")
694+
for component in test_results.keys():
695+
st.write(component)
696+
with col2:
697+
st.markdown("**Status**")
698+
for result in test_results.values():
699+
st.write("✅ Pass" if result else "❌ Fail")
648700

649-
progress.progress(0.5)
650-
use_langsmith = st.session_state.config["tracing"]["langsmith"]["use_langsmith"]
651-
if use_langsmith:
652-
if not os.getenv("LANGCHAIN_API_KEY", ""):
653-
success = False
654-
st.error("Langsmith is enabled but LANGCHAIN_API_KEY is not set")
655-
656-
progress.progress(0.6, text="Testing simple model")
657-
simple_answer = simple_model.invoke("Say hello!")
658-
if simple_answer.content is None:
659-
success = False
660-
st.error("Simple model is not working")
661-
662-
progress.progress(0.7, text="Testing complex model")
663-
complex_answer = complex_model.invoke("Say hello!")
664-
if complex_answer.content is None:
665-
success = False
666-
st.error("Complex model is not working")
667-
668-
progress.progress(0.8, text="Testing embeddings model")
669-
embeddings_answer = embeddings_model.embed_query("Say hello!")
670-
if embeddings_answer is None:
671-
success = False
672-
st.error("Embeddings model is not working")
673-
674-
progress.progress(1.0, text="Done!")
675-
if success:
676-
st.success("Configuration is correct. You can save it now.")
701+
# Overall success message
702+
if all(test_results.values()):
703+
st.success("All tests passed! You can save the configuration now.")
677704
else:
678-
st.error("Configuration is incorrect")
705+
st.error("Some tests failed. Please check the errors above.")
679706

680707
col1, col2 = st.columns([1, 1])
681708
with col1:

0 commit comments

Comments
 (0)