|  | 
| 15 | 15 | import os | 
| 16 | 16 | from typing import Dict, List | 
| 17 | 17 | 
 | 
|  | 18 | +import numpy as np | 
| 18 | 19 | import sounddevice as sd | 
| 19 | 20 | import streamlit as st | 
| 20 | 21 | import tomli | 
| @@ -565,117 +566,143 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in | 
| 565 | 566 |     st.code(toml_string, language="toml") | 
| 566 | 567 | 
 | 
| 567 | 568 |     if st.button("Test Configuration"): | 
| 568 |  | -        success = True | 
| 569 | 569 |         progress = st.progress(0.0) | 
| 570 |  | - | 
| 571 |  | -        vendor = st.session_state.config["vendor"] | 
| 572 |  | -        simple_model_vendor_name = st.session_state.config["vendor"]["simple_model"] | 
| 573 |  | -        complex_model_vendor_name = st.session_state.config["vendor"]["complex_model"] | 
| 574 |  | -        embeddings_model_vendor_name = st.session_state.config["vendor"][ | 
| 575 |  | -            "embeddings_model" | 
| 576 |  | -        ] | 
| 577 |  | - | 
| 578 |  | -        # create simple model | 
| 579 |  | -        progress.progress(0.1) | 
| 580 |  | -        try: | 
| 581 |  | -            if simple_model_vendor_name == "openai": | 
| 582 |  | -                simple_model = ChatOpenAI( | 
| 583 |  | -                    model=st.session_state.config["openai"]["simple_model"] | 
| 584 |  | -                ) | 
| 585 |  | -            elif simple_model_vendor_name == "aws": | 
| 586 |  | -                simple_model = ChatBedrock( | 
| 587 |  | -                    model_id=st.session_state.config["aws"]["simple_model"] | 
| 588 |  | -                ) | 
| 589 |  | -            elif simple_model_vendor_name == "ollama": | 
| 590 |  | -                simple_model = ChatOllama( | 
| 591 |  | -                    model=st.session_state.config["ollama"]["simple_model"], | 
| 592 |  | -                    base_url=st.session_state.config["ollama"]["base_url"], | 
| 593 |  | -                ) | 
| 594 |  | -        except Exception as e: | 
| 595 |  | -            success = False | 
| 596 |  | -            st.error(f"Failed to initialize simple model: {e}") | 
| 597 |  | - | 
| 598 |  | -        # create complex model | 
| 599 |  | -        progress.progress(0.2) | 
| 600 |  | -        try: | 
| 601 |  | -            if complex_model_vendor_name == "openai": | 
| 602 |  | -                complex_model = ChatOpenAI( | 
| 603 |  | -                    model=st.session_state.config["openai"]["complex_model"] | 
| 604 |  | -                ) | 
| 605 |  | -            elif complex_model_vendor_name == "aws": | 
| 606 |  | -                complex_model = ChatBedrock( | 
| 607 |  | -                    model_id=st.session_state.config["aws"]["complex_model"] | 
| 608 |  | -                ) | 
| 609 |  | -            elif complex_model_vendor_name == "ollama": | 
| 610 |  | -                complex_model = ChatOllama( | 
| 611 |  | -                    model=st.session_state.config["ollama"]["complex_model"], | 
|  | 570 | +        test_results = {} | 
|  | 571 | + | 
|  | 572 | +        def create_chat_model(model_type: str): | 
|  | 573 | +            vendor_name = st.session_state.config["vendor"][f"{model_type}_model"] | 
|  | 574 | +            model_name = st.session_state.config[vendor_name][f"{model_type}_model"] | 
|  | 575 | + | 
|  | 576 | +            if vendor_name == "openai": | 
|  | 577 | +                return ChatOpenAI(model=model_name) | 
|  | 578 | +            elif vendor_name == "aws": | 
|  | 579 | +                return ChatBedrock(model_id=model_name) | 
|  | 580 | +            elif vendor_name == "ollama": | 
|  | 581 | +                return ChatOllama( | 
|  | 582 | +                    model=model_name, | 
| 612 | 583 |                     base_url=st.session_state.config["ollama"]["base_url"], | 
| 613 | 584 |                 ) | 
| 614 |  | -        except Exception as e: | 
| 615 |  | -            success = False | 
| 616 |  | -            st.error(f"Failed to initialize complex model: {e}") | 
| 617 |  | - | 
| 618 |  | -        # create embeddings model | 
| 619 |  | -        progress.progress(0.3) | 
| 620 |  | -        try: | 
| 621 |  | -            if embeddings_model_vendor_name == "openai": | 
| 622 |  | -                embeddings_model = OpenAIEmbeddings( | 
| 623 |  | -                    model=st.session_state.config["openai"]["embeddings_model"] | 
| 624 |  | -                ) | 
| 625 |  | -            elif embeddings_model_vendor_name == "aws": | 
| 626 |  | -                embeddings_model = BedrockEmbeddings( | 
| 627 |  | -                    model_id=st.session_state.config["aws"]["embeddings_model"] | 
| 628 |  | -                ) | 
| 629 |  | -            elif embeddings_model_vendor_name == "ollama": | 
| 630 |  | -                embeddings_model = OllamaEmbeddings( | 
| 631 |  | -                    model=st.session_state.config["ollama"]["embeddings_model"], | 
| 632 |  | -                    base_url=st.session_state.config["ollama"]["base_url"], | 
| 633 |  | -                ) | 
| 634 |  | -        except Exception as e: | 
| 635 |  | -            success = False | 
| 636 |  | -            st.error(f"Failed to initialize embeddings model: {e}") | 
| 637 |  | - | 
| 638 |  | -        progress.progress(0.4) | 
| 639 |  | -        use_langfuse = st.session_state.config["tracing"]["langfuse"]["use_langfuse"] | 
| 640 |  | -        if use_langfuse: | 
| 641 |  | -            if not os.getenv("LANGFUSE_SECRET_KEY", "") or not os.getenv( | 
| 642 |  | -                "LANGFUSE_PUBLIC_KEY", "" | 
| 643 |  | -            ): | 
| 644 |  | -                success = False | 
| 645 |  | -                st.error( | 
| 646 |  | -                    "Langfuse is enabled but LANGFUSE_SECRET_KEY or LANGFUSE_PUBLIC_KEY is not set" | 
|  | 585 | +            raise ValueError(f"Unknown vendor: {vendor_name}") | 
|  | 586 | + | 
|  | 587 | +        def test_chat_model(model_type: str) -> bool: | 
|  | 588 | +            try: | 
|  | 589 | +                model = create_chat_model(model_type) | 
|  | 590 | +                answer = model.invoke("Say hello!") | 
|  | 591 | +                return answer.content is not None | 
|  | 592 | +            except Exception as e: | 
|  | 593 | +                st.error(f"{model_type.title()} model error: {e}") | 
|  | 594 | +                return False | 
|  | 595 | + | 
|  | 596 | +        def test_simple_model() -> bool: | 
|  | 597 | +            return test_chat_model("simple") | 
|  | 598 | + | 
|  | 599 | +        def test_complex_model() -> bool: | 
|  | 600 | +            return test_chat_model("complex") | 
|  | 601 | + | 
|  | 602 | +        def test_embeddings_model() -> bool: | 
|  | 603 | +            try: | 
|  | 604 | +                embeddings_model_vendor_name = st.session_state.config["vendor"][ | 
|  | 605 | +                    "embeddings_model" | 
|  | 606 | +                ] | 
|  | 607 | +                if embeddings_model_vendor_name == "openai": | 
|  | 608 | +                    embeddings_model = OpenAIEmbeddings( | 
|  | 609 | +                        model=st.session_state.config["openai"]["embeddings_model"] | 
|  | 610 | +                    ) | 
|  | 611 | +                elif embeddings_model_vendor_name == "aws": | 
|  | 612 | +                    embeddings_model = BedrockEmbeddings( | 
|  | 613 | +                        model_id=st.session_state.config["aws"]["embeddings_model"] | 
|  | 614 | +                    ) | 
|  | 615 | +                elif embeddings_model_vendor_name == "ollama": | 
|  | 616 | +                    embeddings_model = OllamaEmbeddings( | 
|  | 617 | +                        model=st.session_state.config["ollama"]["embeddings_model"], | 
|  | 618 | +                        base_url=st.session_state.config["ollama"]["base_url"], | 
|  | 619 | +                    ) | 
|  | 620 | +                embeddings_answer = embeddings_model.embed_query("Say hello!") | 
|  | 621 | +                return embeddings_answer is not None | 
|  | 622 | +            except Exception as e: | 
|  | 623 | +                st.error(f"Embeddings model error: {e}") | 
|  | 624 | +                return False | 
|  | 625 | + | 
|  | 626 | +        def test_langfuse() -> bool: | 
|  | 627 | +            use_langfuse = st.session_state.config["tracing"]["langfuse"][ | 
|  | 628 | +                "use_langfuse" | 
|  | 629 | +            ] | 
|  | 630 | +            if not use_langfuse: | 
|  | 631 | +                return True | 
|  | 632 | +            return bool(os.getenv("LANGFUSE_SECRET_KEY")) and bool( | 
|  | 633 | +                os.getenv("LANGFUSE_PUBLIC_KEY") | 
|  | 634 | +            ) | 
|  | 635 | + | 
|  | 636 | +        def test_langsmith() -> bool: | 
|  | 637 | +            use_langsmith = st.session_state.config["tracing"]["langsmith"][ | 
|  | 638 | +                "use_langsmith" | 
|  | 639 | +            ] | 
|  | 640 | +            if not use_langsmith: | 
|  | 641 | +                return True | 
|  | 642 | +            return bool(os.getenv("LANGCHAIN_API_KEY")) | 
|  | 643 | + | 
|  | 644 | +        def test_recording_device(index: int, sample_rate: int) -> bool: | 
|  | 645 | +            try: | 
|  | 646 | +                recording = sd.rec( | 
|  | 647 | +                    device=index, | 
|  | 648 | +                    frames=sample_rate,  # one second of audio at the device's sample rate | 
|  | 649 | +                    samplerate=sample_rate, | 
|  | 650 | +                    channels=1, | 
|  | 651 | +                    dtype="int16", | 
| 647 | 652 |                 ) | 
|  | 653 | +                sd.wait() | 
|  | 654 | +                if np.sum(np.abs(recording)) == 0:  # all-zero buffer: the device captured nothing | 
|  | 655 | +                    return False | 
|  | 656 | +                return True | 
|  | 657 | +            except Exception as e: | 
|  | 658 | +                st.error(f"Recording device error: {e}") | 
|  | 659 | +                return False | 
|  | 660 | + | 
|  | 661 | +        # Run tests | 
|  | 662 | +        progress.progress(0.2, "Testing simple model...") | 
|  | 663 | +        test_results["Simple Model"] = test_simple_model() | 
|  | 664 | + | 
|  | 665 | +        progress.progress(0.4, "Testing complex model...") | 
|  | 666 | +        test_results["Complex Model"] = test_complex_model() | 
|  | 667 | + | 
|  | 668 | +        progress.progress(0.6, "Testing embeddings model...") | 
|  | 669 | +        test_results["Embeddings Model"] = test_embeddings_model() | 
|  | 670 | + | 
|  | 671 | +        progress.progress(0.8, "Testing tracing...") | 
|  | 672 | +        test_results["Langfuse"] = test_langfuse() | 
|  | 673 | +        test_results["LangSmith"] = test_langsmith() | 
|  | 674 | + | 
|  | 675 | +        progress.progress(0.9, "Testing recording device...") | 
|  | 676 | +        devices = sd.query_devices() | 
|  | 677 | +        device_index = [device["name"] for device in devices].index( | 
|  | 678 | +            st.session_state.config["asr"]["recording_device_name"] | 
|  | 679 | +        ) | 
|  | 680 | +        sample_rate = int(devices[device_index]["default_samplerate"]) | 
|  | 681 | +        test_results["Recording Device"] = test_recording_device( | 
|  | 682 | +            device_index, sample_rate | 
|  | 683 | +        ) | 
|  | 684 | + | 
|  | 685 | +        progress.progress(1.0) | 
|  | 686 | + | 
|  | 687 | +        # Display results in a table | 
|  | 688 | +        st.subheader("Test Results") | 
|  | 689 | + | 
|  | 690 | +        # Create a two-column table using streamlit columns | 
|  | 691 | +        col1, col2 = st.columns(2) | 
|  | 692 | +        with col1: | 
|  | 693 | +            st.markdown("**Component**") | 
|  | 694 | +            for component in test_results.keys(): | 
|  | 695 | +                st.write(component) | 
|  | 696 | +        with col2: | 
|  | 697 | +            st.markdown("**Status**") | 
|  | 698 | +            for result in test_results.values(): | 
|  | 699 | +                st.write("✅ Pass" if result else "❌ Fail") | 
| 648 | 700 | 
 | 
| 649 |  | -        progress.progress(0.5) | 
| 650 |  | -        use_langsmith = st.session_state.config["tracing"]["langsmith"]["use_langsmith"] | 
| 651 |  | -        if use_langsmith: | 
| 652 |  | -            if not os.getenv("LANGCHAIN_API_KEY", ""): | 
| 653 |  | -                success = False | 
| 654 |  | -                st.error("Langsmith is enabled but LANGCHAIN_API_KEY is not set") | 
| 655 |  | - | 
| 656 |  | -        progress.progress(0.6, text="Testing simple model") | 
| 657 |  | -        simple_answer = simple_model.invoke("Say hello!") | 
| 658 |  | -        if simple_answer.content is None: | 
| 659 |  | -            success = False | 
| 660 |  | -            st.error("Simple model is not working") | 
| 661 |  | - | 
| 662 |  | -        progress.progress(0.7, text="Testing complex model") | 
| 663 |  | -        complex_answer = complex_model.invoke("Say hello!") | 
| 664 |  | -        if complex_answer.content is None: | 
| 665 |  | -            success = False | 
| 666 |  | -            st.error("Complex model is not working") | 
| 667 |  | - | 
| 668 |  | -        progress.progress(0.8, text="Testing embeddings model") | 
| 669 |  | -        embeddings_answer = embeddings_model.embed_query("Say hello!") | 
| 670 |  | -        if embeddings_answer is None: | 
| 671 |  | -            success = False | 
| 672 |  | -            st.error("Embeddings model is not working") | 
| 673 |  | - | 
| 674 |  | -        progress.progress(1.0, text="Done!") | 
| 675 |  | -        if success: | 
| 676 |  | -            st.success("Configuration is correct. You can save it now.") | 
|  | 701 | +        # Overall success message | 
|  | 702 | +        if all(test_results.values()): | 
|  | 703 | +            st.success("All tests passed! You can save the configuration now.") | 
| 677 | 704 |         else: | 
| 678 |  | -            st.error("Configuration is incorrect") | 
|  | 705 | +            st.error("Some tests failed. Please check the errors above.") | 
| 679 | 706 | 
 | 
| 680 | 707 |     col1, col2 = st.columns([1, 1]) | 
| 681 | 708 |     with col1: | 
|  | 
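For context, the `create_chat_model` helper introduced above resolves things in two steps: the `[vendor]` table names which provider serves each model role, and that provider's own table holds the concrete model name. Below is a minimal sketch of the config shape the helper assumes, written as a plain Python dict so it runs standalone; every model name, the base URL, and the role-to-provider assignments are placeholders, not values from this repository.

```python
# Illustrative sketch only: the shape of st.session_state.config that
# create_chat_model(model_type) indexes into. Every value is a placeholder.
sample_config = {
    "vendor": {
        "simple_model": "openai",      # provider serving the "simple" role
        "complex_model": "ollama",     # provider serving the "complex" role
        "embeddings_model": "openai",  # provider serving embeddings
    },
    "openai": {
        "simple_model": "gpt-4o-mini",                 # placeholder model name
        "complex_model": "gpt-4o",                     # placeholder model name
        "embeddings_model": "text-embedding-3-small",  # placeholder model name
    },
    "ollama": {
        "simple_model": "llama3",                # placeholder model name
        "complex_model": "llama3",               # placeholder model name
        "embeddings_model": "nomic-embed-text",  # placeholder model name
        "base_url": "http://localhost:11434",    # placeholder endpoint
    },
}

# Two-step lookup performed by the helper: role -> provider -> model name.
model_type = "complex"
vendor_name = sample_config["vendor"][f"{model_type}_model"]    # -> "ollama"
model_name = sample_config[vendor_name][f"{model_type}_model"]  # -> "llama3"
print(vendor_name, model_name)
```

Keeping the role-to-provider mapping in a single `[vendor]` table is what lets the three per-role vendor branches in the old code collapse into one dispatch function.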
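The recording-device check follows the same shape as the other tests: look up the configured device by name, record one second at its default sample rate, and treat an all-zero buffer as a dead input. Here is a rough standalone sketch of that check under those assumptions; in the PR the device name comes from `config["asr"]["recording_device_name"]`, and the example device name below is a placeholder.

```python
import numpy as np
import sounddevice as sd


def device_captures_audio(device_name: str) -> bool:
    """Record one second from the named input device and check it is not silent."""
    devices = sd.query_devices()
    index = [d["name"] for d in devices].index(device_name)  # ValueError if absent
    sample_rate = int(devices[index]["default_samplerate"])
    recording = sd.rec(
        device=index,
        frames=sample_rate,   # one second of frames at `sample_rate`
        samplerate=sample_rate,
        channels=1,
        dtype="int16",
    )
    sd.wait()  # block until the recording finishes
    return bool(np.abs(recording).sum() > 0)  # all zeros -> nothing captured


# Example usage (device name is a placeholder):
# print(device_captures_audio("MacBook Pro Microphone"))
```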