1313# limitations under the License. 
1414
1515import  os 
16+ from  functools  import  partial 
1617from  typing  import  Dict , List 
1718
1819import  numpy  as  np 
20+ import  requests 
1921import  sounddevice  as  sd 
2022import  streamlit  as  st 
2123import  tomli 
2224import  tomli_w 
25+ from  elevenlabs  import  ElevenLabs 
2326from  langchain_aws  import  BedrockEmbeddings , ChatBedrock 
2427from  langchain_ollama  import  ChatOllama , OllamaEmbeddings 
2528from  langchain_openai  import  ChatOpenAI , OpenAIEmbeddings 
@@ -203,18 +206,34 @@ def prev_step():
203206                    embeddings_model_vendor  =  st .text_input (
204207                        "Embeddings model vendor" , value = current_embeddings_model_vendor 
205208                    )
206-          if   use_advanced_config : 
207-             st .session_state .config ["vendor" ] =  {
208-                 "simple_model" : simple_model_vendor ,
209-                 "complex_model" : complex_model_vendor ,
210-                 "embeddings_model" : embeddings_model_vendor ,
211-             }
209+ 
210+                  st .session_state .config ["vendor" ] =  {
211+                      "simple_model" : simple_model_vendor ,
212+                      "complex_model" : complex_model_vendor ,
213+                      "embeddings_model" : embeddings_model_vendor ,
214+                  }
212215        else :
213216            st .session_state .config ["vendor" ] =  {
214217                "simple_model" : vendor ,
215218                "complex_model" : vendor ,
216219                "embeddings_model" : vendor ,
217220            }
221+         st .session_state .config ["openai" ] =  {
222+             "simple_model" : simple_model ,
223+             "complex_model" : complex_model ,
224+             "embeddings_model" : embeddings_model ,
225+         }
226+         st .session_state .config ["aws" ] =  {
227+             "simple_model" : simple_model ,
228+             "complex_model" : complex_model ,
229+             "embeddings_model" : embeddings_model ,
230+         }
231+         st .session_state .config ["ollama" ] =  {
232+             "simple_model" : simple_model ,
233+             "complex_model" : complex_model ,
234+             "embeddings_model" : embeddings_model ,
235+         }
236+ 
218237        # Navigation buttons 
219238        col1 , col2  =  st .columns ([1 , 1 ])
220239        with  col1 :
@@ -641,6 +660,32 @@ def test_langsmith():
641660                return  True 
642661            return  bool (os .getenv ("LANGCHAIN_API_KEY" ))
643662
663+         def  test_tts ():
664+             vendor  =  st .session_state .config ["tts" ]["vendor" ]
665+             if  vendor  ==  "elevenlabs" :
666+                 try :
667+                     client  =  ElevenLabs (api_key = os .getenv ("ELEVENLABS_API_KEY" ))
668+                     output  =  client .generate (text = "Hello, world!" )
669+                     output  =  list (output )
670+                     return  True 
671+                 except  Exception  as  e :
672+                     st .error (f"TTS error: { e }  )
673+                 return  False 
674+             elif  vendor  ==  "opentts" :
675+                 try :
676+                     params  =  {
677+                         "voice" : "glow-speak:en-us_mary_ann" ,
678+                         "text" : "Hello, world!" ,
679+                     }
680+                     response  =  requests .get (
681+                         "http://localhost:5500/api/tts" , params = params 
682+                     )
683+                     if  response .status_code  ==  200 :
684+                         return  True 
685+                 except  Exception  as  e :
686+                     st .error (f"TTS error: { e }  )
687+                 return  False 
688+ 
644689        def  test_recording_device (index : int , sample_rate : int ):
645690            try :
646691                recording  =  sd .rec (
@@ -658,31 +703,32 @@ def test_recording_device(index: int, sample_rate: int):
658703                st .error (f"Recording device error: { e }  )
659704                return  False 
660705
661-         # Run tests 
662-         progress .progress (0.2 , "Testing simple model..." )
663-         test_results ["Simple Model" ] =  test_simple_model ()
664- 
665-         progress .progress (0.4 , "Testing complex model..." )
666-         test_results ["Complex Model" ] =  test_complex_model ()
667- 
668-         progress .progress (0.6 , "Testing embeddings model..." )
669-         test_results ["Embeddings Model" ] =  test_embeddings_model ()
706+         # TODO: Add ASR test 
707+         # TODO: Move tests to a separate file in tests/ 
670708
671-         progress .progress (0.8 , "Testing tracing..." )
672-         test_results ["Langfuse" ] =  test_langfuse ()
673-         test_results ["LangSmith" ] =  test_langsmith ()
709+         # Run tests 
674710
675-         progress .progress (0.9 , "Testing recording device..." )
676711        devices  =  sd .query_devices ()
677712        device_index  =  [device ["name" ] for  device  in  devices ].index (
678713            st .session_state .config ["asr" ]["recording_device_name" ]
679714        )
680715        sample_rate  =  int (devices [device_index ]["default_samplerate" ])
681-         test_results ["Recording Device" ] =  test_recording_device (
682-             device_index , sample_rate 
683-         )
684- 
685-         progress .progress (1.0 )
716+         tests  =  [
717+             (test_simple_model , "Simple Model" ),
718+             (test_complex_model , "Complex Model" ),
719+             (test_embeddings_model , "Embeddings Model" ),
720+             (test_langfuse , "Langfuse" ),
721+             (test_langsmith , "LangSmith" ),
722+             (test_tts , "TTS" ),
723+             (
724+                 partial (test_recording_device , device_index , sample_rate ),
725+                 "Recording Device" ,
726+             ),
727+         ]
728+         progress .progress (0.0 , "Running tests..." )
729+         for  i , (test , name ) in  enumerate (tests ):
730+             test_results [name ] =  test ()
731+             progress .progress ((1  +  i ) /  len (tests ), f"Testing { name }  )
686732
687733        # Display results in a table 
688734        st .subheader ("Test Results" )
0 commit comments