Skip to content

Commit

Permalink
improved chat visualization options
Browse files Browse the repository at this point in the history
  • Loading branch information
josancamon19 committed Aug 21, 2024
1 parent 7955415 commit 82a4879
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 24 deletions.
74 changes: 66 additions & 8 deletions backend/scripts/rag/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import sys
import uuid
from datetime import datetime
from pathlib import Path

import streamlit as st
Expand All @@ -10,6 +8,7 @@
if project_root not in sys.path:
sys.path.append(project_root)

from current import *
from _shared import *
from models.chat import Message
from utils.llm import qa_rag
Expand All @@ -18,6 +17,10 @@
# Initialize session state
if 'messages' not in st.session_state:
st.session_state.messages = []
if 'visualizations' not in st.session_state:
st.session_state.visualizations = {}
if 'contexts' not in st.session_state:
st.session_state.contexts = {}


def add_message(message: Message):
Expand All @@ -38,12 +41,33 @@ def send_message(text: str):
)
add_message(human_message)

# Simulating the AI response (replace this with your actual AI logic)
context_str, memories = retrieve_rag_context(uid, get_messages())
# Retrieve context and generate response
data = retrieve_rag_context(uid, get_messages(), return_context_params=True)
topics, dates_range = [], []

if len(data) == 2:
context_str, memories = data
else:
# noinspection PyTupleAssignmentBalance
context_str, memories, topics, dates_range = data

response: str = qa_rag(context_str, get_messages(), None)

# Generate visualization
ai_message_id = str(uuid.uuid4())
if topics:
file_name = f'{ai_message_id}.html'
generate_topics_visualization(topics, file_name)
visualization_path = os.path.join(project_root, 'scripts', 'rag', file_name)
if os.path.exists(visualization_path):
with open(visualization_path, 'r') as f:
st.session_state.visualizations[ai_message_id] = f.read()

# Store context
st.session_state.contexts[ai_message_id] = context_str

ai_message = Message(
id=str(uuid.uuid4()),
id=ai_message_id,
text=response,
created_at=datetime.utcnow(),
sender='ai',
Expand All @@ -52,19 +76,53 @@ def send_message(text: str):
add_message(ai_message)


# Remove horizontal padding
st.markdown("""
<style>
.block-container {
padding-top: 1rem;
padding-bottom: 0rem;
padding-left: 1rem;
padding-right: 1rem;
}
.main .block-container {
max-width: 100%;
padding-left: 2rem;
padding-right: 2rem;
}
.stChatMessage {
padding-left: 0px;
padding-right: 0px;
}
.stChatMessage .stChatMessageContent {
padding-left: 0.5rem;
padding-right: 0.5rem;
}
</style>
""", unsafe_allow_html=True)

# Streamlit UI
st.title("Simple Chat Application")
st.title("RAG Chat with Embedding Visualization")

# Display chat messages
# Display chat messages with inline visualizations and context
for message in get_messages():
with st.chat_message(message.sender):
st.write(f"{message.sender}: {message.text}")

# Display visualization if available
if message.id in st.session_state.visualizations:
st.components.v1.html(st.session_state.visualizations[message.id], height=400)

# Display context used by AI
if message.id in st.session_state.contexts:
with st.expander("Show Context Used"):
st.text(st.session_state.contexts[message.id])

# Chat input
user_input = st.chat_input("Type your message here...")
if user_input:
send_message(user_input)
st.rerun() # Rerun the app to display the new message

# Display current user ID
st.sidebar.write(f"Current User ID: {uid}")
# st.sidebar.write(f"Current User ID: {uid}")
26 changes: 13 additions & 13 deletions backend/scripts/rag/current.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import uuid
from datetime import datetime
from typing import Dict, Tuple
from typing import Dict

from _shared import *
from models.chat import Message
from utils.llm import determine_requires_context


def _get_mesage(text: str, sender: str):
Expand Down Expand Up @@ -66,14 +65,15 @@ def get_markers(data, data_points, color, name, show_top=None):
)


def visualize():

context: Tuple = determine_requires_context(conversation)
if not context or not context[0]:
print('No context is needed')
return
topics = context[0]
def generate_topics_visualization(topics: List[str], file_path: str = 'embedding_visualization_multi_topic.html'):
# context: Tuple = determine_requires_context(conversation)
# if not context or not context[0]:
# print('No context is needed')
# return
# topics = context[0]
# topics = ['Business', 'Entrepreneurship', 'Failures']
os.makedirs('visualizations/', exist_ok=True)
file_path = os.path.join('visualizations/', file_path)

data = get_data(topics)
all_embeddings = np.array([item['vector'] for item in data.values()])
Expand Down Expand Up @@ -107,14 +107,14 @@ def visualize():
title='Embedding Visualization for Multiple Topics',
xaxis_title='UMAP Dimension 1',
yaxis_title='UMAP Dimension 2',
width=1000,
height=800,
width=800,
height=600,
showlegend=True,
hovermode='closest'
)

generate_html_visualization(fig, file_name='embedding_visualization_multi_topic.html')
generate_html_visualization(fig, file_name=file_path)


if __name__ == '__main__':
visualize()
generate_topics_visualization()
15 changes: 12 additions & 3 deletions backend/utils/retrieval/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def retrieve_memories_for_topics(uid: str, topics: List[str], dates_range: List)

def get_better_memory_chunk(memory: Memory, topics: List[str], context_data: dict) -> str:
print('get_better_memory_chunk', memory.id, topics)
# TODO: prompt should use categories and be optional; it's possible it doesn't return anything.
conversation = TranscriptSegment.segments_as_string(memory.transcript_segments)
if num_tokens_from_string(conversation) < 250:
return Memory.memories_to_string([memory])
Expand All @@ -55,8 +56,11 @@ def get_better_memory_chunk(memory: Memory, topics: List[str], context_data: dic
context_data[memory.id] = chunk


def retrieve_rag_context(uid: str, prev_messages: List[Message]) -> Tuple[str, List[Memory]]:
def retrieve_rag_context(
uid: str, prev_messages: List[Message], return_context_params: bool = False
) -> Tuple[str, List[Memory]]:
requires = requires_context(prev_messages)

if not requires:
return '', []

Expand All @@ -71,6 +75,7 @@ def retrieve_rag_context(uid: str, prev_messages: List[Message]) -> Tuple[str, L

memories_id_to_topics = {}
if topics:
# TODO: for time-based queries, topics should come back empty; use categories instead of topics there
memories_id_to_topics, memories = retrieve_memories_for_topics(uid, topics, dates_range)
id_counter = Counter(memory['id'] for memory in memories)
memories = sorted(memories, key=lambda x: id_counter[x['id']], reverse=True)
Expand All @@ -86,13 +91,17 @@ def retrieve_rag_context(uid: str, prev_messages: List[Message]) -> Tuple[str, L
context_data = {}
threads = []
for memory in memories:
topics = memories_id_to_topics.get(memory.id, [])
t = threading.Thread(target=get_better_memory_chunk, args=(memory, topics, context_data))
# TODO: if get_better_memory_chunk sometimes returns an empty chunk, the corresponding memories are not filtered out
m_topics = memories_id_to_topics.get(memory.id, [])
t = threading.Thread(target=get_better_memory_chunk, args=(memory, m_topics, context_data))
threads.append(t)
[t.start() for t in threads]
[t.join() for t in threads]
context_str = '\n'.join(context_data.values()).strip()
else:
context_str = Memory.memories_to_string(memories)

if return_context_params:
return context_str, (memories if context_str else []), topics, dates_range

return context_str, (memories if context_str else [])

0 comments on commit 82a4879

Please sign in to comment.