user name + facts during chat + QA prompt interactions
josancamon19 committed Aug 23, 2024
1 parent 7352c53 commit ec1119a
Showing 9 changed files with 134 additions and 66 deletions.
9 changes: 4 additions & 5 deletions backend/database/auth.py
@@ -24,14 +24,13 @@ def get_user_from_uid(uid: str):


 def get_user_name(uid: str):
-    # TODO: "The User" or "User"?
-    if cached_name := get_cached_user_name(uid):
-        return cached_name
+    # if cached_name := get_cached_user_name(uid):
+    #     return cached_name

     user = get_user_from_uid(uid)
-    display_name = user.get('display_name', 'User').split(' ')[0] if user else 'User'
+    display_name = user.get('display_name', 'User').split(' ')[0] if user else 'The User'
     if display_name == 'AnonymousUser':
-        display_name = 'User'
+        display_name = 'The User'

     cache_user_name(uid, display_name)
     return display_name
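
For reference, a quick sketch of the resulting fallback behavior (the uids are hypothetical, and since the cache lookup above is commented out in this commit, every call hits Firestore):

    print(get_user_name('uid-with-profile'))    # first name, e.g. 'Jane' for display_name 'Jane Doe'
    print(get_user_name('uid-anonymous'))       # 'The User' (display_name was 'AnonymousUser')
    print(get_user_name('uid-without-record'))  # 'The User' (no user document found)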
26 changes: 22 additions & 4 deletions backend/database/facts.py
@@ -1,15 +1,16 @@
 from typing import List

 from google.cloud import firestore
+from google.cloud.firestore_v1 import FieldFilter

 from ._client import db


 def get_facts(uid: str, limit: int = 100, offset: int = 0):
-    facts_ref = (
-        db.collection('users').document(uid).collection('facts')
-    )
-    facts_ref = facts_ref.order_by('created_at', direction=firestore.Query.DESCENDING)
+    # TODO: cache this
+    facts_ref = db.collection('users').document(uid).collection('facts')
+    facts_ref = facts_ref.order_by('created_at', direction=firestore.Query.DESCENDING).where(
+        filter=FieldFilter('deleted', '==', False))
     facts_ref = facts_ref.limit(limit).offset(offset)
     return [doc.to_dict() for doc in facts_ref.stream()]

@@ -31,3 +32,20 @@ def delete_facts(uid: str):
     for doc in facts_ref.stream():
         batch.delete(doc.reference)
     batch.commit()
+
+
+def delete_facts_for_memory(uid: str, memory_id: str):
+    batch = db.batch()
+    user_ref = db.collection('users').document(uid)
+    facts_ref = user_ref.collection('facts')
+    query = (
+        facts_ref.where(filter=FieldFilter('memory_id', '==', memory_id))
+        .where(filter=FieldFilter('deleted', '==', False))
+    )
+
+    removed_ids = []
+    for doc in query.stream():
+        batch.update(doc.reference, {'deleted': True})
+        removed_ids.append(doc.id)
+    batch.commit()
+    print('delete_facts_for_memory', memory_id, len(removed_ids))
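
A minimal sketch of how the new soft delete composes with the updated get_facts, assuming a configured Firestore client; the ids are hypothetical:

    from database.facts import get_facts, delete_facts_for_memory

    uid = 'test-user'          # hypothetical user id
    memory_id = 'memory-123'   # hypothetical memory id

    # Flag every fact extracted from this memory as deleted...
    delete_facts_for_memory(uid, memory_id)

    # ...and subsequent reads exclude them, since get_facts now filters deleted == False.
    remaining = get_facts(uid)
    assert all(f['memory_id'] != memory_id for f in remaining)

Soft-deleting (flipping a deleted flag) rather than removing documents preserves history, at the cost of an extra where filter on every read; note that combining the equality filter with order_by on a different field may require a composite Firestore index.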
19 changes: 19 additions & 0 deletions backend/models/facts.py
@@ -4,6 +4,7 @@

 from pydantic import BaseModel, Field

+from database._client import document_id_from_seed
 from models.memory import CategoryEnum


@@ -21,6 +22,11 @@ class Fact(BaseModel):
     content: str = Field(description="The content of the fact")
     category: FactCategory = Field(description="The category of the fact", default=FactCategory.other)

+    @staticmethod
+    def get_facts_as_str(facts):
+        existing_facts = [f"{f.content} ({f.category.value})" for f in facts]
+        return '' if not existing_facts else '\n- ' + '\n- '.join(existing_facts)


 class FactDB(Fact):
     id: str
@@ -37,3 +43,16 @@ class FactDB(Fact):
     manually_added: bool = False
     edited: bool = False
     deleted: bool = False
+
+    @staticmethod
+    def from_fact(fact: Fact, uid: str, memory_id: str, memory_category: CategoryEnum) -> 'FactDB':
+        return FactDB(
+            id=document_id_from_seed(fact.content),
+            uid=uid,
+            content=fact.content,
+            category=fact.category,
+            created_at=datetime.utcnow(),
+            updated_at=datetime.utcnow(),
+            memory_id=memory_id,
+            memory_category=memory_category,
+        )
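
A short sketch of the two new helpers together (field values are hypothetical, and it assumes CategoryEnum has an other member). Since document_id_from_seed derives the id from the fact's content, re-extracting identical content presumably maps to the same document rather than creating a duplicate:

    fact = Fact(content='Prefers morning meetings', category=FactCategory.other)
    fact_db = FactDB.from_fact(fact, uid='test-user', memory_id='memory-123',
                               memory_category=CategoryEnum.other)

    print(Fact.get_facts_as_str([fact]))
    # -> '\n- Prefers morning meetings (other)'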
14 changes: 11 additions & 3 deletions backend/routers/chat.py
@@ -5,9 +5,12 @@
 from fastapi import APIRouter, Depends

 import database.chat as chat_db
+from database.auth import get_user_name
+from database.facts import get_facts
 from models.chat import Message, SendMessageRequest, MessageSender
-from utils.other import endpoints as auth
+from models.facts import Fact
 from utils.llm import qa_rag, initial_chat_message
+from utils.other import endpoints as auth
 from utils.plugins import get_plugin_by_id
 from utils.retrieval.rag import retrieve_rag_context

@@ -39,7 +42,9 @@ def send_message(
     messages = filter_messages(messages, plugin_id)

     context_str, memories = retrieve_rag_context(uid, messages)
-    response: str = qa_rag(context_str, messages, plugin)
+    user_name = get_user_name(uid)
+    user_facts = [Fact(**fact) for fact in get_facts(uid)]
+    response: str = qa_rag(user_name, user_facts, context_str, messages, plugin)

     ai_message = Message(
         id=str(uuid.uuid4()),
@@ -59,7 +64,10 @@
 def initial_message_util(uid: str, plugin_id: Optional[str] = None):
     plugin = get_plugin_by_id(plugin_id)

-    text = initial_chat_message(plugin)
+    user_name = get_user_name(uid)
+    user_facts = [Fact(**fact) for fact in get_facts(uid)]
+
+    text = initial_chat_message(user_name, user_facts, plugin)
     ai_message = Message(
         id=str(uuid.uuid4()),
         text=text,
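
Both endpoints now repeat the same two lookups; a tiny helper could consolidate them. A sketch only, using imports already present in this file (get_user_context is a hypothetical name, not part of this commit):

    def get_user_context(uid: str):
        # One place to fetch the display name and the user's non-deleted facts.
        user_name = get_user_name(uid)
        user_facts = [Fact(**fact) for fact in get_facts(uid)]
        return user_name, user_facts

    # e.g. inside send_message / initial_message_util:
    # user_name, user_facts = get_user_context(uid)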
7 changes: 6 additions & 1 deletion backend/scripts/rag/app.py
@@ -23,6 +23,9 @@
 from models.transcript_segment import TranscriptSegment
 from utils.llm import qa_rag
 from utils.retrieval.rag import retrieve_rag_context
+from database.auth import get_user_name
+from database.facts import get_facts
+from models.facts import Fact

 # File to store the state
 STATE_FILE = 'chat_state.json'
@@ -120,7 +123,9 @@ def send_message(text: str):
     else:
         context_str, memories, topics, dates_range = data

-    response: str = qa_rag(context_str, get_messages(), None)
+    user_name = get_user_name(uid)
+    user_facts = [Fact(**fact) for fact in get_facts(uid)]
+    response: str = qa_rag(user_name, user_facts, context_str, get_messages(), None)

     # Generate visualization
     ai_message_id = str(uuid.uuid4())
24 changes: 5 additions & 19 deletions backend/scripts/rag/facts.py
@@ -1,5 +1,4 @@
 import threading
-from datetime import datetime
 from typing import Tuple

 import firebase_admin
@@ -13,48 +12,35 @@
 from utils.llm import new_facts_extractor


-def get_preferences_from_memory(memories: List[dict], uid: str) -> List[Tuple[str, List[Fact]]]:
+def get_facts_from_memory(memories: List[dict], uid: str) -> List[Tuple[str, List[Fact]]]:
     all_facts: List[Tuple[str, List[Fact]]] = []
     only_facts: List[Fact] = []
     user_name = get_user_name(uid)
+    print('User:', user_name)
     for i, memory in enumerate(memories):
         data = Memory(**memory)
-        try:
-            new_facts = new_facts_extractor(data.transcript_segments, user_name, only_facts)
-        except Exception as e:
-            # LLM failed to parse output, we can skip 1 or 2, every 200.
+        new_facts = new_facts_extractor(data.transcript_segments, user_name, only_facts)
         if not new_facts:
             continue
         all_facts.append([memory['id'], new_facts])
         only_facts.extend(new_facts)

         print(uid, 'Memory #', i + 1, 'retrieved', len(new_facts), 'facts')

-    # for fact in only_facts:
-    #     print(fact.category.value.upper(), '~', fact.content)
     return all_facts


 def execute_for_user(uid: str):
     facts_db.delete_facts(uid)

     memories = memories_db.get_memories(uid, limit=2000)
-    data: List[Tuple[str, List[Fact]]] = get_preferences_from_memory(memories, uid)
+    data: List[Tuple[str, List[Fact]]] = get_facts_from_memory(memories, uid)
     parsed_facts = []
     for item in data:
         memory_id, facts = item
         memory = next((m for m in memories if m['id'] == memory_id), None)
         for fact in facts:
-            parsed_facts.append(FactDB(
-                id=document_id_from_seed(fact.content),
-                uid=uid,
-                content=fact.content,
-                category=fact.category,
-                created_at=datetime.utcnow(),
-                updated_at=datetime.utcnow(),
-                memory_id=memory_id,
-                memory_category=memory['structured']['category'],
-            ))
+            parsed_facts.append(FactDB.from_fact(fact, uid, memory['id'], memory['structured']['category']))
     facts_db.save_facts(uid, [fact.dict() for fact in parsed_facts])


46 changes: 29 additions & 17 deletions backend/utils/llm.py
@@ -1,10 +1,9 @@
-import tiktoken
 import json
 from datetime import datetime
-from typing import List, Tuple, Optional
+from typing import List, Optional

 from langchain_core.output_parsers import PydanticOutputParser
-from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from pydantic import BaseModel, Field

@@ -237,13 +236,15 @@ def generate_embedding(content: str) -> List[float]:
 # ****************************************
 # ************* CHAT BASICS **************
 # ****************************************
-def initial_chat_message(plugin: Optional[Plugin] = None) -> str:
+def initial_chat_message(user_name: str, user_facts: List[Fact], plugin: Optional[Plugin] = None) -> str:
     if plugin is None:
-        prompt = '''
+        prompt = f'''
         You are an AI with the following characteristics:
         Name: Friend,
         Personality/Description: A friendly and helpful AI assistant that aims to make your life easier and more enjoyable.
         Task: Provide assistance, answer questions, and engage in meaningful conversations.
+
+        You are made for {user_name}, you already know the following facts about {user_name}: {Fact.get_facts_as_str(user_facts)}.

         Send an initial message to start the conversation, make sure this message reflects your personality, \
         humor, and characteristics.
@@ -256,6 +257,8 @@ def initial_chat_message(plugin: Optional[Plugin] = None) -> str:
         Name: {plugin.name},
         Personality/Description: {plugin.chat_prompt},
         Task: {plugin.memory_prompt}
+
+        You are made for {user_name}, you already know the following facts about {user_name}: {Fact.get_facts_as_str(user_facts)}.

         Send an initial message to start the conversation, make sure this message reflects your personality, \
         humor, and characteristics.
@@ -401,8 +404,6 @@ def new_facts_extractor(
     # TODO: later, focus a lot on user said things, rn is hard because of speech profile accuracy
     # TODO: include negative facts too? Things the user doesn't like?

-    existing_facts = [f"{f.content} ({f.category.value})" for f in existing_facts]
-    facts = '' if not existing_facts else '\n- ' + '\n- '.join(existing_facts)
     prompt = f'''
     You are an experienced detective, whose job is to create detailed profile personas based on conversations.
@@ -418,7 +419,7 @@
     This way we can create a more accurate profile.
     Include from 0 up to 3 valuable facts, If you don't find any new facts, or ones worth storing, output an empty list of facts.

-    Existing Facts: {facts}
+    Existing Facts: {Fact.get_facts_as_str(existing_facts)}

     Conversation:
     ```
@@ -427,12 +428,18 @@
     '''.replace('    ', '').strip()
     # print(prompt)

-    with_parser = llm.with_structured_output(UserFacts)
-    response: UserFacts = with_parser.invoke(prompt)
-    return response.facts
+    try:
+        with_parser = llm.with_structured_output(UserFacts)
+        response: UserFacts = with_parser.invoke(prompt)
+        return response.facts
+    except Exception as e:
+        print(f'Error extracting new facts: {e}')
+        return []


-def qa_rag(context: str, messages: List[Message], plugin: Optional[Plugin] = None) -> str:
+def qa_rag(
+    user_name: str, user_facts: List[Fact], context: str, messages: List[Message], plugin: Optional[Plugin] = None
+) -> str:
     conversation_history = Message.get_messages_as_string(
         messages, use_user_name_if_available=True, use_plugin_name_if_available=True
     )
@@ -442,13 +449,16 @@ def qa_rag(context: str, messages: List[Message], plugin: Optional[Plugin] = Non
         plugin_info = f"Your name is: {plugin.name}, and your personality/description is '{plugin.description}'.\nMake sure to reflect your personality in your response.\n"

     prompt = f"""
-    You are an assistant for question-answering tasks. Use the following pieces of retrieved context and the conversation history to continue the conversation.
-    If you don't know the answer, just say that you didn't find any related information or you that don't know. Use three sentences maximum and keep the answer concise.
+    You are an assistant for question-answering tasks.
+    You are made for {user_name}, you already know the following facts about {user_name}: {Fact.get_facts_as_str(user_facts)}.
+    Use what you know about {user_name}, the following pieces of retrieved context and the chat history to continue the chat.
+    If you don't know the answer, just say that there's no available information about it. Use three sentences maximum and keep the answer concise.
     If the message doesn't require context, it will be empty, so follow-up the conversation casually.
     If there's not enough information to provide a valuable answer, ask the user for clarification questions.
     {plugin_info}

-    Conversation History:
+    Chat History:
     {conversation_history}

     Context:
@@ -461,11 +471,13 @@ def qa_rag(context: str, messages: List[Message], plugin: Optional[Plugin] = Non
     return llm.invoke(prompt).content


-def qa_emotional_rag(context: str, memories: List[Memory], emotion: str) -> str:
+def qa_emotional_rag(user_name: str, user_facts: List[Fact], context: str, memories: List[Memory], emotion: str) -> str:
     conversation_history = Memory.memories_to_string(memories)

     prompt = f"""
-    You are a thoughtful friend. Use the following pieces of retrieved context, the conversation history and user's emotions to share your thoughts and give the user positive advice.
+    You are a thoughtful friend of {user_name}, you already know the following facts about {user_name}: {Fact.get_facts_as_str(user_facts)}.
+    Use the following pieces of retrieved context, the conversation history and user's emotions to share your thoughts and give the user positive advice.
     Thoughts and positive advice should be like a chat message. Keep it short.

     User's emotions:
     {emotion}
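
End to end, the updated call path looks roughly like this; a sketch mirroring routers/chat.py, where the uid is hypothetical and messages would normally be loaded from chat history:

    from database.auth import get_user_name
    from database.facts import get_facts
    from models.facts import Fact
    from utils.llm import qa_rag
    from utils.retrieval.rag import retrieve_rag_context

    uid = 'test-user'  # hypothetical
    messages = []      # normally loaded via chat_db

    user_name = get_user_name(uid)
    user_facts = [Fact(**fact) for fact in get_facts(uid)]
    context_str, memories = retrieve_rag_context(uid, messages)

    # The facts are flattened into the prompt via Fact.get_facts_as_str(user_facts).
    answer = qa_rag(user_name, user_facts, context_str, messages, plugin=None)
    print(answer)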