-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathstreamlit_app.py
144 lines (116 loc) · 4.9 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import streamlit as st
from audio_recorder_streamlit import audio_recorder
import uuid
from core.agent import init_agent, init_chromadb, init_content_embeddings, init_qna_retrieval, REDIS_HOST
from langchain.memory import RedisChatMessageHistory, StreamlitChatMessageHistory
from core.llm_wrapers import LLMChatHandler
from utils import tts, stt
from localization.locales import LOCALES
from PIL import Image
# Module-level list that the sidebar "reset chat" button clears.
chat_history = []

# UI language picker; its value drives every LOCALES lookup and is passed to
# the speech-to-text / text-to-speech helpers via session_state.
language = st.radio(
    "Мова/Language",
    horizontal=True,
    index=0,
    options=["uk", "en"],
)
st.session_state["language"] = language
st.caption(f"Session: {st.session_state.get('session_id', '')}")

# Greeting used as the first assistant turn whenever a chat is started or reset.
INITIAL_MESSAGE = [
    dict(
        id=uuid.uuid4().hex,
        role="assistant",
        content=LOCALES[language]["hello_assistant"],
    ),
]
@st.cache_resource
def init_cache():
    """Build the retrieval-augmented agent and return it.

    Wrapped in ``st.cache_resource`` so the expensive setup (embedder,
    Chroma client, content retriever, QnA chain, agent) is not rebuilt on
    every Streamlit rerun.

    Returns:
        The agent object produced by ``init_agent``.
    """
    embedder, chroma_client = init_chromadb()
    retriever = init_content_embeddings(embedder, chroma_client)
    qa_chain, llm = init_qna_retrieval(retriever, embedder, chroma_client)
    return init_agent(qa_chain, llm)
# Shared agent instance (built by the cached init_cache above).
AGENT = init_cache()

# Seed per-session defaults the first time a browser session runs the script;
# keys that already exist are left untouched. "language" is normally already
# set by the language picker earlier in the script.
st.session_state.setdefault("messages", [])
st.session_state.setdefault("history", [])
st.session_state.setdefault("language", language)
def get_llm_client(session_id):
    """Create a chat handler for *session_id* backed by Redis message history.

    Args:
        session_id: Identifier of the chat session whose history is stored
            in Redis database 2 on ``REDIS_HOST:6379``.

    Returns:
        An ``LLMChatHandler`` wrapping the shared ``AGENT`` and that history.
    """
    redis_url = f"redis://{REDIS_HOST}:6379/2"
    history = RedisChatMessageHistory(session_id=session_id, url=redis_url)
    return LLMChatHandler(AGENT, history)
def append_message(text, audio=None):
    """Append a user message, ask the agent for a reply, and append the reply.

    Args:
        text: The user's message (typed, or transcribed from a recording).
        audio: Optional recording bytes stored on the user message so the
            chat can play it back.

    The agent's iteration/time-limit message, an empty reply, or any error
    raised while querying the agent is replaced by a hard-coded Ukrainian
    apology (regardless of the selected UI language), so the chat always
    gets an assistant turn. Replies starting with "Накладна" get the
    ``invoice.jpg`` image attached.
    """
    import logging

    fallback = "Вибачте, але я не можу відповісти на дане запитання."

    user_msg = {"role": "user", "content": text, "id": uuid.uuid4().hex}
    if audio:
        user_msg["audio"] = audio
    st.session_state.messages.append(user_msg)

    # The sidebar reset wipes session_state (including "session_id") while
    # keeping the chat open; create an id instead of failing every request.
    session_id = st.session_state.get("session_id")
    if not session_id:
        session_id = uuid.uuid4().hex
        st.session_state["session_id"] = session_id

    try:
        response = get_llm_client(session_id).send_message(text)
    except Exception:
        # Keep the UI usable, but record the cause instead of hiding it the
        # way a bare ``except:`` did (which also caught KeyboardInterrupt).
        logging.getLogger(__name__).exception(
            "Agent request failed for session %s", session_id
        )
        response = fallback
    else:
        if response == "Agent stopped due to iteration limit or time limit.":
            response = "Вибачте, але я не можу відповісти на дане запитання. Зверніться до служби підтримки."
        # NOTE(review): assumes send_message returns a str or a falsy value.
        if not response:
            response = fallback

    assistant_msg = {"role": "assistant", "content": response, "id": uuid.uuid4().hex}
    if response.startswith("Накладна"):
        assistant_msg["image"] = "invoice.jpg"
    st.session_state.messages.append(assistant_msg)
def build_sidebar():
    """Render the localized sidebar text and the "reset chat" button.

    The sidebar content is read from ``localization/sidebar_<language>.md``.
    Pressing reset clears ``st.session_state`` and re-seeds it with the
    greeting, a fresh session id and an empty history. It also clears the
    module-level ``chat_history`` and reruns the script.
    """
    # Locale files can contain non-ASCII text (e.g. the "uk" locale); pin the
    # encoding so reading does not depend on the platform's default codec.
    with open(f"localization/sidebar_{language}.md", "r", encoding="utf-8") as sidebar_file:
        sidebar_content = sidebar_file.read()
    st.sidebar.markdown(sidebar_content)

    # Add a reset button
    if st.sidebar.button(LOCALES[language]["reset_chat"]):
        # Snapshot the keys so deletion never mutates the view being iterated.
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        # Store a copy so later appends never touch the module-level constant.
        st.session_state["messages"] = list(INITIAL_MESSAGE)
        # The wipe above also removed "session_id", which append_message reads;
        # re-create it (as the start-chat path does) so the chat keeps working.
        st.session_state["session_id"] = uuid.uuid4().hex
        st.session_state["history"] = []
        chat_history.clear()
        st.experimental_rerun()
def build_chat():
    """Render the chat: collect typed or recorded input, then draw all messages.

    Typed prompts and new voice recordings are sent through
    ``append_message``. Every stored message is then rendered. Messages that
    already carry audio get a player; the others get a "synthesize" button
    that runs TTS, stores the audio on the message and reruns the script.
    Messages with an "image" entry also show that image.
    """
    import hashlib

    if prompt := st.chat_input():
        append_message(prompt)
    # if uploaded_file := st.file_uploader(
    #     LOCALES[language]["choose_audiofile"], type="wav"
    # ):
    #     st.session_state.messages.append(
    #         {
    #             "role": "user",
    #             "content": "*audio file*",
    #             "audio": uploaded_file.getvalue(),
    #             "id": uuid.uuid4().hex,
    #         }
    #     )
    if audio_bytes := audio_recorder(LOCALES[language]["record_audio"]):
        # The recorder component keeps returning its last recording on every
        # rerun. Without this guard the same audio would be transcribed and
        # sent to the agent again on each rerun (typing a prompt, clicking
        # "synthesize", ...). Only process recordings not seen before.
        audio_digest = hashlib.sha256(audio_bytes).hexdigest()
        if st.session_state.get("last_audio_digest") != audio_digest:
            st.session_state["last_audio_digest"] = audio_digest
            text = stt(audio_bytes, st.session_state["language"])
            append_message(text, audio_bytes)

    for message in st.session_state.messages:
        msg_component = st.chat_message(message["role"])
        msg_component.write(message["content"])
        if "audio" in message:
            msg_component.audio(message["audio"], format="audio/wav")
        else:
            btn = msg_component.button(
                LOCALES[language]["synthesize"], key=message["id"]
            )
            if btn:
                # Cache the synthesized audio on the message, then rerun so it
                # is drawn with an audio player instead of the button.
                message["audio"] = tts(message["content"], st.session_state["language"])
                st.experimental_rerun()
        if "image" in message:
            msg_component.image(Image.open(message["image"]))
# Page body: the sidebar is always drawn. Until a conversation exists (no
# "messages" key, or an empty list) only the start button is shown;
# afterwards the full chat UI is rendered.
build_sidebar()

if st.session_state.get("messages"):
    build_chat()
elif st.button(LOCALES[language]["start_chat"]):
    st.session_state["messages"] = INITIAL_MESSAGE
    st.session_state["session_id"] = uuid.uuid4().hex
    st.experimental_rerun()