-
Notifications
You must be signed in to change notification settings - Fork 1
/
Home.py
130 lines (109 loc) · 4.66 KB
/
Home.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import time
import streamlit as st
import dotenv
from PIL import Image
from api.ocr import ocr
from main import OpenAIChatAgentExt, DBApiExt
# Pull variables from a local .env file into os.environ before anything reads it.
dotenv.load_dotenv()

# Streamlit executes this script top to bottom on every run (see the st.rerun()
# calls below); st.session_state is what persists between runs, so the heavy
# objects are only constructed when they are missing from the session.
if "agent" not in st.session_state:
    st.session_state.agent = OpenAIChatAgentExt(os.environ, OpenAIChatAgentExt.ANYSCALE_MODELS[1])
if "db_api" not in st.session_state:
    st.session_state.db_api = DBApiExt(os.environ)

# Simple per-session values, seeded in this fixed order when absent.
for _state_key, _initial in (
    ("chat_id", None),
    ("disable_input", False),
    ("raw_ocr", None),
    ("pro_ocr", None),
    ("image", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial

# Short module-level aliases used by the page code below.
agent = st.session_state.agent
api = st.session_state.db_api
image, raw_ocr, pro_ocr = (st.session_state[name] for name in ("image", "raw_ocr", "pro_ocr"))

# Called on every script run.
# NOTE(review): credentials are hard-coded here; consider sourcing them from the
# environment like the other DBApiExt configuration.
api.auth("admin", "admin")
def reset_session():
    """Wipe the whole Streamlit session and install a fresh DBApiExt client.

    All other session values (agent, chat_id, OCR results, ...) are absent
    afterwards and get re-seeded by the initialisation code at the top of the
    script on its next run.
    """
    state = st.session_state
    state.clear()
    state.db_api = DBApiExt(os.environ)
def load_chat(chat):
    """Make *chat* the active conversation in this session.

    Copies the stored ``message_history`` onto the agent, stores the chat's
    ``title`` in session state and remembers its id as ``chat_id``.

    Args:
        chat: Object exposing ``.get(field)`` and ``.id`` — presumably a
            Firestore document snapshot (TODO confirm).
    """
    state = st.session_state
    state.agent.message_history = chat.get("message_history")
    state.title = chat.get("title")
    state.chat_id = chat.id
def delete_chat(chat):
    """Delete *chat* from storage, resetting the session first if it is the open chat.

    Args:
        chat: Object exposing ``.id`` and ``.reference.delete()`` — presumably a
            Firestore document snapshot (TODO confirm).
    """
    is_open_chat = chat.id == st.session_state.chat_id
    if is_open_chat:
        reset_session()
    chat.reference.delete()
def message_stream(text: str, delay: float = 0.01):
    """Yield *text* one character at a time to produce a typing effect.

    Meant to be fed to ``st.write_stream``.

    Args:
        text: The complete message to emit.
        delay: Seconds to pause after each character. ``0`` (or a negative
            value) disables the pause instead of raising from ``time.sleep``.

    Yields:
        Successive single-character strings of *text*.
    """
    for ch in text:
        yield ch
        if delay > 0:
            time.sleep(delay)
# Browser-tab title/icon; Streamlit requires this before other element calls.
# NOTE(review): tab title "Helt Pro" differs from the on-page title "Health Guard".
st.set_page_config(page_title="Helt Pro", page_icon="💪🏽")
# Hide the auto-generated multipage navigation in the sidebar.
st.set_option("client.showSidebarNavigation", False)

# Injected CSS that hides Streamlit's #MainMenu element.
_HIDE_MAIN_MENU_CSS = """
<style>
#MainMenu{
visibility: hidden;
}
</style>
"""
st.markdown(_HIDE_MAIN_MENU_CSS, unsafe_allow_html=True)

# Page header.
st.title("Health Guard")
st.subheader("LLM-powered Health Assistant")
st.caption("Powered by LLaMA2 and Mistral")

# Sidebar switch between the two views rendered below.
selected_tab = st.sidebar.radio("Navigation", ["Ask", "Preferences"])
if selected_tab == "Ask":
    # ---- Step 1: obtain ingredient text; the script stops here until it exists ----
    if not raw_ocr:
        opts = ("Take a picture", "Upload an image", "Manual Entry")
        option = st.radio("Choose an option:", opts, index=1)
        if option == opts[0]:
            image = st.session_state.image = st.camera_input("Take a picture of the ingredients")
            # OCR the captured photo exactly like an upload. Do not store a
            # placeholder here: any truthy value is sent straight to
            # process_raw_ocr below and persisted in the session.
            if image is not None:
                raw_ocr = st.session_state.raw_ocr = ocr(Image.open(image))
        elif option == opts[1]:
            image = st.session_state.image = st.file_uploader("Upload an image")
            if image is not None:
                st.image(image, caption="Uploaded Image", use_column_width=True)
                raw_ocr = st.session_state.raw_ocr = ocr(Image.open(image))
        else:
            image = st.session_state.image = None
            raw_ocr = st.session_state.raw_ocr = st.text_area("Enter the ingredients here", height=200)
        # Only show and process real text (not None / "" before any input).
        if raw_ocr:
            st.success(raw_ocr)
            with agent:
                pro_ocr = st.session_state.pro_ocr = agent.process_raw_ocr(raw_ocr)
            st.rerun()
        # No usable text yet: end this run and wait for the next interaction.
        st.stop()

    # ---- Step 2: show what was extracted ----
    if image:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    st.success(raw_ocr)
    with st.expander("Ingredients"):
        st.subheader(pro_ocr.product_name)
        for ingredient in pro_ocr.ingredients:
            # TODO: per-ingredient details via api.fetch(ingredient) are disabled.
            st.warning(ingredient)

    # ---- Step 3: chat about the product ----
    history_box = st.container()
    # message_history[0] is not rendered — presumably the system prompt (TODO confirm).
    for message in st.session_state.agent.message_history[1:]:
        with history_box.chat_message(message["role"]):
            st.write(message["content"])
    with agent:
        prompt = st.chat_input(
            key="user_input",
            placeholder="Type something...",
            # Disable the input while a reply is being generated.
            on_submit=lambda: st.session_state.update({"disable_input": True}),
            disabled=st.session_state.disable_input,
        )
        if prompt:
            st.session_state.disable_input = False
            with history_box.chat_message("user"):
                st.write(prompt)
            with history_box.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = agent.chat(
                        prompt,
                        ingredients=pro_ocr.ingredients,
                        preferences=api.get_preferences(),
                    )
                st.write_stream(message_stream(response))
            st.rerun()
elif selected_tab == "Preferences":
    st.write("Manage Your Preferences")
    preferences = api.get_preferences()
    for i, preference in enumerate(preferences):
        col1, col2 = st.columns([20, 1])
        col1.success(preference)
        # Key by position, not by text: duplicate preferences, or one spelled like
        # an existing session_state key ("image", "agent", ...), would make
        # Streamlit raise on the widget key.
        col2.button(
            "🗑️",
            key=f"remove_preference_{i}",
            on_click=api.remove_preference,
            args=(preference,),
        )
    new_preference = st.text_input("Add a preference")
    if st.button("Add", disabled=not new_preference):
        api.add_preference(new_preference)
        st.rerun()