@@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLM Session Memory"
"# LLM Message History"
]
},
{
@@ -15,7 +15,7 @@
"\n",
"The solution to this problem is to append the previous conversation history to each subsequent call to the LLM.\n",
"\n",
"This notebook will show how to use Redis to structure and store and retrieve this conversational session memory."
"This notebook will show how to use Redis to structure and store and retrieve this conversational message history."
]
},
{
@@ -32,8 +32,8 @@
}
],
"source": [
"from redisvl.extensions.session_manager import StandardSessionManager\n",
"chat_session = StandardSessionManager(name='student tutor')"
"from redisvl.extensions.message_history import MessageHistory\n",
"chat_history = MessageHistory(name='student tutor')"
]
},
{
@@ -52,8 +52,8 @@
"metadata": {},
"outputs": [],
"source": [
"chat_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful geography tutor, giving simple and short answers to questions about Europen countries.\"})\n",
"chat_session.add_messages([\n",
"chat_history.add_message({\"role\":\"system\", \"content\":\"You are a helpful geography tutor, giving simple and short answers to questions about European countries.\"})\n",
"chat_history.add_messages([\n",
" {\"role\":\"user\", \"content\":\"What is the capital of France?\"},\n",
" {\"role\":\"llm\", \"content\":\"The capital is Paris.\"},\n",
" {\"role\":\"user\", \"content\":\"And what is the capital of Spain?\"},\n",
@@ -88,7 +88,7 @@
}
],
"source": [
"context = chat_session.get_recent()\n",
"context = chat_history.get_recent()\n",
"for message in context:\n",
" print(message)"
]
@@ -97,7 +97,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In many LLM flows the conversation progresses in a series of prompt and response pairs. session managers offer a convienience function `store()` to add these simply."
"In many LLM flows the conversation progresses in a series of prompt and response pairs. Message history offer a convenience function `store()` to add these simply."
]
},
{
@@ -121,9 +121,9 @@
"source": [
"prompt = \"what is the size of England compared to Portugal?\"\n",
"response = \"England is larger in land area than Portal by about 15000 square miles.\"\n",
"chat_session.store(prompt, response)\n",
"chat_history.store(prompt, response)\n",
"\n",
"context = chat_session.get_recent(top_k=6)\n",
"context = chat_history.get_recent(top_k=6)\n",
"for message in context:\n",
" print(message)"
]
@@ -160,33 +160,33 @@
}
],
"source": [
"chat_session.add_message({\"role\":\"system\", \"content\":\"You are a helpful algebra tutor, giving simple answers to math problems.\"}, session_tag='student two')\n",
"chat_session.add_messages([\n",
"chat_history.add_message({\"role\":\"system\", \"content\":\"You are a helpful algebra tutor, giving simple answers to math problems.\"}, session_tag='student two')\n",
"chat_history.add_messages([\n",
" {\"role\":\"user\", \"content\":\"What is the value of x in the equation 2x + 3 = 7?\"},\n",
" {\"role\":\"llm\", \"content\":\"The value of x is 2.\"},\n",
" {\"role\":\"user\", \"content\":\"What is the value of y in the equation 3y - 5 = 7?\"},\n",
" {\"role\":\"llm\", \"content\":\"The value of y is 4.\"}],\n",
" session_tag='student two'\n",
" )\n",
"\n",
"for math_message in chat_session.get_recent(session_tag='student two'):\n",
"for math_message in chat_history.get_recent(session_tag='student two'):\n",
" print(math_message)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Semantic conversation memory\n",
"## Semantic message history\n",
"For longer conversations our list of messages keeps growing. Since LLMs are stateless we have to continue to pass this conversation history on each subsequent call to ensure the LLM has the correct context.\n",
"\n",
"A typical flow looks like this:\n",
"```\n",
"while True:\n",
" prompt = input('enter your next question')\n",
" context = chat_session.get_recent()\n",
" context = chat_history.get_recent()\n",
" response = LLM_api_call(prompt=prompt, context=context)\n",
" chat_session.store(prompt, response)\n",
" chat_history.store(prompt, response)\n",
"```\n",
"\n",
"This works, but as context keeps growing so too does our LLM token count, which increases latency and cost.\n",
@@ -195,12 +195,12 @@
"\n",
"A better solution is to pass only the relevant conversational context on each subsequent call.\n",
"\n",
"For this, RedisVL has the `SemanticSessionManager`, which uses vector similarity search to return only semantically relevant sections of the conversation."
"For this, RedisVL has the `SemanticMessageHistory`, which uses vector similarity search to return only semantically relevant sections of the conversation."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -212,10 +212,10 @@
}
],
"source": [
"from redisvl.extensions.session_manager import SemanticSessionManager\n",
"semantic_session = SemanticSessionManager(name='tutor')\n",
"from redisvl.extensions.message_history import SemanticMessageHistory\n",
"semantic_history = SemanticMessageHistory(name='tutor')\n",
"\n",
"semantic_session.add_messages(chat_session.get_recent(top_k=8))"
"semantic_history.add_messages(chat_history.get_recent(top_k=8))"
]
},
{
@@ -234,8 +234,8 @@
],
"source": [
"prompt = \"what have I learned about the size of England?\"\n",
"semantic_session.set_distance_threshold(0.35)\n",
"context = semantic_session.get_relevant(prompt)\n",
"semantic_history.set_distance_threshold(0.35)\n",
"context = semantic_history.get_relevant(prompt)\n",
"for message in context:\n",
" print(message)"
]
@@ -266,9 +266,9 @@
}
],
"source": [
"semantic_session.set_distance_threshold(0.7)\n",
"semantic_history.set_distance_threshold(0.7)\n",
"\n",
"larger_context = semantic_session.get_relevant(prompt)\n",
"larger_context = semantic_history.get_relevant(prompt)\n",
"for message in larger_context:\n",
" print(message)"
]
@@ -300,17 +300,17 @@
}
],
"source": [
"semantic_session.store(\n",
"semantic_history.store(\n",
" prompt=\"what is the smallest country in Europe?\",\n",
" response=\"Monaco is the smallest country in Europe at 0.78 square miles.\" # Incorrect. Vatican City is the smallest country in Europe\n",
" )\n",
"\n",
"# get the key of the incorrect message\n",
"context = semantic_session.get_recent(top_k=1, raw=True)\n",
"context = semantic_history.get_recent(top_k=1, raw=True)\n",
"bad_key = context[0]['entry_id']\n",
"semantic_session.drop(bad_key)\n",
"semantic_history.drop(bad_key)\n",
"\n",
"corrected_context = semantic_session.get_recent()\n",
"corrected_context = semantic_history.get_recent()\n",
"for message in corrected_context:\n",
" print(message)"
]
@@ -321,7 +321,8 @@
"metadata": {},
"outputs": [],
"source": [
"chat_session.clear()"
"chat_history.clear()\n",
"semantic_history.clear()"
]
}
],
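Taken together, the renamed API gives the notebook's end-to-end loop. Below is a minimal sketch assuming a Redis instance at the default local address; `LLM_api_call` is a placeholder carried over from the notebook's pseudocode, not a RedisVL function:

```
from redisvl.extensions.message_history import SemanticMessageHistory

semantic_history = SemanticMessageHistory(name='tutor')

def LLM_api_call(prompt, context):
    # Placeholder for a real model call.
    return f"(model answer to: {prompt})"

prompt = "what have I learned about the size of England?"
semantic_history.set_distance_threshold(0.35)    # tighten the relevance cutoff
context = semantic_history.get_relevant(prompt)  # only semantically relevant turns
response = LLM_api_call(prompt=prompt, context=context)
semantic_history.store(prompt, response)         # persist the new exchange
```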
10 changes: 5 additions & 5 deletions redisvl/extensions/constants.py
@@ -1,19 +1,19 @@
"""
Constants used within the extension classes SemanticCache, BaseSessionManager,
StandardSessionManager,SemanticSessionManager and SemanticRouter.
Constants used within the extension classes SemanticCache, BaseMessageHistory,
MessageHistory, SemanticMessageHistory and SemanticRouter.
These constants are also used within these classes' corresponding schemas.
"""

# BaseSessionManager
# BaseMessageHistory
ID_FIELD_NAME: str = "entry_id"
ROLE_FIELD_NAME: str = "role"
CONTENT_FIELD_NAME: str = "content"
TOOL_FIELD_NAME: str = "tool_call_id"
TIMESTAMP_FIELD_NAME: str = "timestamp"
SESSION_FIELD_NAME: str = "session_tag"

# SemanticSessionManager
SESSION_VECTOR_FIELD_NAME: str = "vector_field"
# SemanticMessageHistory
MESSAGE_VECTOR_FIELD_NAME: str = "vector_field"

# SemanticCache
REDIS_KEY_FIELD_NAME: str = "key"
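As a hypothetical call-site sketch (not part of this diff), these constants keep field names consistent between schema definitions and queries; the renamed `MESSAGE_VECTOR_FIELD_NAME` still resolves to the same `"vector_field"` value:

```
# Hypothetical snippet; the constant names come from this module, but the
# schema dicts are illustrative only (a real vector field also needs
# attributes such as dims and a distance metric).
from redisvl.extensions.constants import (
    CONTENT_FIELD_NAME,
    MESSAGE_VECTOR_FIELD_NAME,
    ROLE_FIELD_NAME,
)

message_fields = [
    {"name": ROLE_FIELD_NAME, "type": "tag"},
    {"name": CONTENT_FIELD_NAME, "type": "text"},
    {"name": MESSAGE_VECTOR_FIELD_NAME, "type": "vector"},  # was SESSION_VECTOR_FIELD_NAME
]
```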
5 changes: 5 additions & 0 deletions redisvl/extensions/message_history/__init__.py
@@ -0,0 +1,5 @@
from redisvl.extensions.message_history.base_history import BaseMessageHistory
from redisvl.extensions.message_history.message_history import MessageHistory
from redisvl.extensions.message_history.semantic_history import SemanticMessageHistory

__all__ = ["BaseMessageHistory", "MessageHistory", "SemanticMessageHistory"]
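For downstream users, the package `__init__` makes migration a one-line import change (sketch; the old module path is taken from the notebook diff above):

```
# Before (old module path, removed by this PR):
#   from redisvl.extensions.session_manager import StandardSessionManager, SemanticSessionManager
# After:
from redisvl.extensions.message_history import (
    BaseMessageHistory,
    MessageHistory,
    SemanticMessageHistory,
)
```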
157 changes: 157 additions & 0 deletions redisvl/extensions/message_history/base_history.py
@@ -0,0 +1,157 @@
from typing import Any, Dict, List, Optional, Union

from redisvl.extensions.constants import (
    CONTENT_FIELD_NAME,
    ROLE_FIELD_NAME,
    TOOL_FIELD_NAME,
)
from redisvl.extensions.message_history.schema import ChatMessage
from redisvl.utils.utils import create_ulid


class BaseMessageHistory:

    def __init__(
        self,
        name: str,
        session_tag: Optional[str] = None,
    ):
        """Initialize message history with index.

        Message History stores the current and previous user text prompts and
        LLM responses to allow for enriching future prompts with session
        context. Messages are stored individually, as user prompts or LLM
        responses.

        Args:
            name (str): The name of the message history index.
            session_tag (str): Tag to be added to entries to link to a specific
                conversation session. Defaults to instance ULID.
        """
        self._name = name
        self._session_tag = session_tag or create_ulid()

    def clear(self) -> None:
        """Clears the chat message history."""
        raise NotImplementedError

    def delete(self) -> None:
        """Clear all conversation history and remove any search indices."""
        raise NotImplementedError

    def drop(self, id_field: Optional[str] = None) -> None:
        """Remove a specific exchange from the conversation history.

        Args:
            id_field (Optional[str]): The id_field of the entry to delete.
                If None then the last entry is deleted.
        """
        raise NotImplementedError

    @property
    def messages(self) -> Union[List[str], List[Dict[str, str]]]:
        """Returns the full chat history."""
        raise NotImplementedError

    def get_recent(
        self,
        top_k: int = 5,
        as_text: bool = False,
        raw: bool = False,
        session_tag: Optional[str] = None,
    ) -> Union[List[str], List[Dict[str, str]]]:
        """Retrieve the recent conversation history in sequential order.

        Args:
            top_k (int): The number of previous messages to return. Default is 5.
            as_text (bool): Whether to return the messages as plain content
                strings or as full message dictionaries.
            raw (bool): Whether to return the full Redis hash entry or just the
                prompt and response.
            session_tag (str): Tag of the entries linked to a specific
                conversation session. Defaults to instance ULID.

        Returns:
            Union[List[str], List[Dict[str, str]]]: A list of content strings
                if as_text is true, otherwise a list of message dictionaries.

        Raises:
            ValueError: If top_k is not an integer greater than or equal to 0.
        """
        raise NotImplementedError

    def _format_context(
        self, messages: List[Dict[str, Any]], as_text: bool
    ) -> Union[List[str], List[Dict[str, str]]]:
        """Extracts the prompt and response fields from the Redis hashes and
        formats them as either flat dictionaries or strings.

        Args:
            messages (List[Dict[str, Any]]): The messages from the message
                history index.
            as_text (bool): Whether to return the messages as plain content
                strings or as full message dictionaries.

        Returns:
            Union[List[str], List[Dict[str, str]]]: A list of content strings
                if as_text is true, otherwise a list of message dictionaries.
        """
        context = []

        for message in messages:
            chat_message = ChatMessage(**message)

            if as_text:
                context.append(chat_message.content)
            else:
                chat_message_dict = {
                    ROLE_FIELD_NAME: chat_message.role,
                    CONTENT_FIELD_NAME: chat_message.content,
                }
                if chat_message.tool_call_id is not None:
                    chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id

                context.append(chat_message_dict)  # type: ignore

        return context

    def store(
        self, prompt: str, response: str, session_tag: Optional[str] = None
    ) -> None:
        """Insert a prompt:response pair into the message history. A timestamp
        is associated with each exchange so that they can be later sorted
        in sequential ordering after retrieval.

        Args:
            prompt (str): The user prompt to the LLM.
            response (str): The corresponding LLM response.
            session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
        """
        raise NotImplementedError

    def add_messages(
        self, messages: List[Dict[str, str]], session_tag: Optional[str] = None
    ) -> None:
        """Insert a list of prompts and responses into the message history.
        A timestamp is associated with each so that they can be later sorted
        in sequential ordering after retrieval.

        Args:
            messages (List[Dict[str, str]]): The list of user prompts and LLM responses.
            session_tag (Optional[str]): The tag to mark the messages with. Defaults to None.
        """
        raise NotImplementedError

    def add_message(
        self, message: Dict[str, str], session_tag: Optional[str] = None
    ) -> None:
        """Insert a single prompt or response into the message history.
        A timestamp is associated with it so that it can be later sorted
        in sequential ordering after retrieval.

        Args:
            message (Dict[str,str]): The user prompt or LLM response.
            session_tag (Optional[str]): The tag to mark the message with. Defaults to None.
        """
        raise NotImplementedError
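To make the abstract contract concrete, here is a minimal in-memory subclass sketch. It is illustrative only; the real implementations are the Redis-backed `MessageHistory` and `SemanticMessageHistory` above, and this toy version ignores `session_tag`, `raw`, and timestamps:

```
from typing import Dict, List, Optional, Union

from redisvl.extensions.message_history import BaseMessageHistory


class InMemoryMessageHistory(BaseMessageHistory):
    """Toy subclass showing which methods a backend must provide."""

    def __init__(self, name: str, session_tag: Optional[str] = None):
        super().__init__(name, session_tag)
        self._messages: List[Dict[str, str]] = []

    def clear(self) -> None:
        self._messages = []

    def add_message(
        self, message: Dict[str, str], session_tag: Optional[str] = None
    ) -> None:
        self._messages.append(message)

    def add_messages(
        self, messages: List[Dict[str, str]], session_tag: Optional[str] = None
    ) -> None:
        for message in messages:
            self.add_message(message, session_tag)

    def store(
        self, prompt: str, response: str, session_tag: Optional[str] = None
    ) -> None:
        # One exchange becomes two messages, mirroring the base docstrings.
        self.add_messages(
            [{"role": "user", "content": prompt}, {"role": "llm", "content": response}],
            session_tag,
        )

    def get_recent(
        self,
        top_k: int = 5,
        as_text: bool = False,
        raw: bool = False,
        session_tag: Optional[str] = None,
    ) -> Union[List[str], List[Dict[str, str]]]:
        # Guard top_k=0 explicitly: list[-0:] would return the whole list.
        recent = self._messages[-top_k:] if top_k else []
        return [m["content"] for m in recent] if as_text else recent
```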