Skip to content

Commit ec32761

Browse files
committed
feat(configurator.py): add TTS testing functionality for ElevenLabs and OpenTTS vendors
refactor(configurator.py): consolidate test execution into a loop for improved readability and maintainability fix(configurator.py): ensure advanced config updates vendor settings correctly chore(configurator.py): add TODOs for future improvements in ASR testing and test file organization
1 parent 5571e49 commit ec32761

File tree

1 file changed

+70
-24
lines changed

1 file changed

+70
-24
lines changed

src/rai/rai/utils/configurator.py

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,16 @@
1313
# limitations under the License.
1414

1515
import os
16+
from functools import partial
1617
from typing import Dict, List
1718

1819
import numpy as np
20+
import requests
1921
import sounddevice as sd
2022
import streamlit as st
2123
import tomli
2224
import tomli_w
25+
from elevenlabs import ElevenLabs
2326
from langchain_aws import BedrockEmbeddings, ChatBedrock
2427
from langchain_ollama import ChatOllama, OllamaEmbeddings
2528
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
@@ -203,18 +206,34 @@ def prev_step():
203206
embeddings_model_vendor = st.text_input(
204207
"Embeddings model vendor", value=current_embeddings_model_vendor
205208
)
206-
if use_advanced_config:
207-
st.session_state.config["vendor"] = {
208-
"simple_model": simple_model_vendor,
209-
"complex_model": complex_model_vendor,
210-
"embeddings_model": embeddings_model_vendor,
211-
}
209+
210+
st.session_state.config["vendor"] = {
211+
"simple_model": simple_model_vendor,
212+
"complex_model": complex_model_vendor,
213+
"embeddings_model": embeddings_model_vendor,
214+
}
212215
else:
213216
st.session_state.config["vendor"] = {
214217
"simple_model": vendor,
215218
"complex_model": vendor,
216219
"embeddings_model": vendor,
217220
}
221+
st.session_state.config["openai"] = {
222+
"simple_model": simple_model,
223+
"complex_model": complex_model,
224+
"embeddings_model": embeddings_model,
225+
}
226+
st.session_state.config["aws"] = {
227+
"simple_model": simple_model,
228+
"complex_model": complex_model,
229+
"embeddings_model": embeddings_model,
230+
}
231+
st.session_state.config["ollama"] = {
232+
"simple_model": simple_model,
233+
"complex_model": complex_model,
234+
"embeddings_model": embeddings_model,
235+
}
236+
218237
# Navigation buttons
219238
col1, col2 = st.columns([1, 1])
220239
with col1:
@@ -641,6 +660,32 @@ def test_langsmith():
641660
return True
642661
return bool(os.getenv("LANGCHAIN_API_KEY"))
643662

663+
def test_tts():
664+
vendor = st.session_state.config["tts"]["vendor"]
665+
if vendor == "elevenlabs":
666+
try:
667+
client = ElevenLabs(api_key=os.getenv("ELEVENLABS_API_KEY"))
668+
output = client.generate(text="Hello, world!")
669+
output = list(output)
670+
return True
671+
except Exception as e:
672+
st.error(f"TTS error: {e}")
673+
return False
674+
elif vendor == "opentts":
675+
try:
676+
params = {
677+
"voice": "glow-speak:en-us_mary_ann",
678+
"text": "Hello, world!",
679+
}
680+
response = requests.get(
681+
"http://localhost:5500/api/tts", params=params
682+
)
683+
if response.status_code == 200:
684+
return True
685+
except Exception as e:
686+
st.error(f"TTS error: {e}")
687+
return False
688+
644689
def test_recording_device(index: int, sample_rate: int):
645690
try:
646691
recording = sd.rec(
@@ -658,31 +703,32 @@ def test_recording_device(index: int, sample_rate: int):
658703
st.error(f"Recording device error: {e}")
659704
return False
660705

661-
# Run tests
662-
progress.progress(0.2, "Testing simple model...")
663-
test_results["Simple Model"] = test_simple_model()
664-
665-
progress.progress(0.4, "Testing complex model...")
666-
test_results["Complex Model"] = test_complex_model()
667-
668-
progress.progress(0.6, "Testing embeddings model...")
669-
test_results["Embeddings Model"] = test_embeddings_model()
706+
# TODO: Add ASR test
707+
# TODO: Move tests to a separate file in tests/
670708

671-
progress.progress(0.8, "Testing tracing...")
672-
test_results["Langfuse"] = test_langfuse()
673-
test_results["LangSmith"] = test_langsmith()
709+
# Run tests
674710

675-
progress.progress(0.9, "Testing recording device...")
676711
devices = sd.query_devices()
677712
device_index = [device["name"] for device in devices].index(
678713
st.session_state.config["asr"]["recording_device_name"]
679714
)
680715
sample_rate = int(devices[device_index]["default_samplerate"])
681-
test_results["Recording Device"] = test_recording_device(
682-
device_index, sample_rate
683-
)
684-
685-
progress.progress(1.0)
716+
tests = [
717+
(test_simple_model, "Simple Model"),
718+
(test_complex_model, "Complex Model"),
719+
(test_embeddings_model, "Embeddings Model"),
720+
(test_langfuse, "Langfuse"),
721+
(test_langsmith, "LangSmith"),
722+
(test_tts, "TTS"),
723+
(
724+
partial(test_recording_device, device_index, sample_rate),
725+
"Recording Device",
726+
),
727+
]
728+
progress.progress(0.0, "Running tests...")
729+
for i, (test, name) in enumerate(tests):
730+
test_results[name] = test()
731+
progress.progress((1 + i) / len(tests), f"Testing {name}...")
686732

687733
# Display results in a table
688734
st.subheader("Test Results")

0 commit comments

Comments
 (0)