Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure CosmosDB memory #3434

Merged
merged 14 commits into from
Apr 25, 2023
6 changes: 6 additions & 0 deletions langchain/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
ConversationStringBufferMemory,
)
from langchain.memory.buffer_window import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
from langchain.memory.chat_message_histories.cosmos_db_aio import (
CosmosDBChatMessageHistoryAsync,
)
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
Expand Down Expand Up @@ -40,4 +44,6 @@
"DynamoDBChatMessageHistory",
"PostgresChatMessageHistory",
"VectorStoreRetrieverMemory",
"CosmosDBChatMessageHistory",
"CosmosDBChatMessageHistoryAsync",
]
6 changes: 6 additions & 0 deletions langchain/memory/chat_message_histories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from langchain.memory.chat_message_histories.cosmos_db import CosmosDBChatMessageHistory
from langchain.memory.chat_message_histories.cosmos_db_aio import (
CosmosDBChatMessageHistoryAsync,
)
from langchain.memory.chat_message_histories.dynamodb import DynamoDBChatMessageHistory
from langchain.memory.chat_message_histories.file import FileChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
Expand All @@ -8,4 +12,6 @@
"RedisChatMessageHistory",
"PostgresChatMessageHistory",
"FileChatMessageHistory",
"CosmosDBChatMessageHistory",
"CosmosDBChatMessageHistoryAsync",
]
134 changes: 134 additions & 0 deletions langchain/memory/chat_message_histories/cosmos_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Azure CosmosDB Memory History."""
import logging
from types import TracebackType
from typing import Optional, Type

from azure.cosmos import ContainerProxy, CosmosClient, PartitionKey
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.identity import DefaultAzureCredential

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CosmosDBChatMessageHistory(BaseChatMessageHistory):
    """Chat history backed by Azure CosmosDB.

    One Cosmos item holds one conversation: the item ``id`` is the session ID,
    the partition key (``/user_id``) is the user ID, and ``messages`` holds the
    serialized history.

    The container must be prepared before messages are read or written, either
    by calling :meth:`prepare_cosmos` or by using the instance as a context
    manager.
    """

    def __init__(
        self,
        cosmos_endpoint: str,
        cosmos_database: str,
        cosmos_container: str,
        credential: DefaultAzureCredential,
        session_id: str,
        user_id: str,
        ttl: Optional[int] = None,
    ):
        """
        Initializes a new instance of the CosmosDBChatMessageHistory class.

        :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
        :param cosmos_database: The name of the database to use.
        :param cosmos_container: The name of the container to use.
        :param credential: The credential to use to authenticate to Azure Cosmos DB.
        :param session_id: The session ID to use, can be overwritten while loading.
        :param user_id: The user ID to use, can be overwritten while loading.
        :param ttl: The time to live (in seconds) to use for documents in the container.
        """
        self.cosmos_endpoint = cosmos_endpoint
        self.cosmos_database = cosmos_database
        self.cosmos_container = cosmos_container
        self.credential = credential
        self.session_id = session_id
        self.user_id = user_id
        self.ttl = ttl

        # Start with an empty history. load_messages() only assigns
        # self.messages when a stored session is found, so without this the
        # first append for a brand-new session has no list to append to.
        self.messages = []

        self._client = CosmosClient(
            url=self.cosmos_endpoint, credential=self.credential
        )
        # Set by prepare_cosmos(); None means "not prepared yet".
        self._container: Optional["ContainerProxy"] = None

    def prepare_cosmos(self) -> None:
        """Prepare the CosmosDB client.

        Creates the database and the container (partitioned on ``/user_id``,
        with ``ttl`` as the default time to live) if they do not exist, then
        loads any stored messages for the current session.

        Use this function or the context manager to make sure your database is ready.
        """
        database = self._client.create_database_if_not_exists(self.cosmos_database)
        self._container = database.create_container_if_not_exists(
            self.cosmos_container,
            partition_key=PartitionKey("/user_id"),
            default_ttl=self.ttl,
        )
        self.load_messages()

    def __enter__(self) -> "CosmosDBChatMessageHistory":
        """Context manager entry point: open the client and prepare the container."""
        self._client.__enter__()
        self.prepare_cosmos()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Context manager exit: persist the history, then close the client."""
        try:
            self.upsert_messages()
        finally:
            # Always release the client, even when the final upsert fails.
            self._client.__exit__(exc_type, exc_val, traceback)

    def load_messages(self) -> None:
        """Retrieve the messages from Cosmos.

        Leaves the in-memory history untouched when the item cannot be read or
        holds no messages.

        :raises ValueError: If the container has not been prepared.
        """
        if not self._container:
            raise ValueError("Container not initialized")
        try:
            item = self._container.read_item(
                item=self.session_id, partition_key=self.user_id
            )
        except CosmosHttpResponseError:
            logger.info("no session found")
            return
        # upsert_messages() stores the output of messages_to_dict() directly,
        # so any non-empty value here is what messages_from_dict() expects.
        if "messages" in item and item["messages"]:
            self.messages = messages_from_dict(item["messages"])

    def add_user_message(self, message: str) -> None:
        """Add a user message to the memory and store it in Cosmos."""
        self.upsert_messages(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the memory and store it in Cosmos."""
        self.upsert_messages(AIMessage(content=message))

    def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
        """Update the cosmosdb item.

        :param new_message: Optional message to append before writing the
            whole history back to Cosmos.
        :raises ValueError: If the container has not been prepared.
        """
        if new_message is not None:
            self.messages.append(new_message)
        if not self._container:
            raise ValueError("Container not initialized")
        self._container.upsert_item(
            body={
                "id": self.session_id,
                "user_id": self.user_id,
                "messages": messages_to_dict(self.messages),
            }
        )

    def clear(self) -> None:
        """Clear session memory from this memory and cosmos.

        The Cosmos item is only deleted when the container has been prepared;
        otherwise only the in-memory history is reset.
        """
        self.messages = []
        if self._container:
            self._container.delete_item(
                item=self.session_id, partition_key=self.user_id
            )
156 changes: 156 additions & 0 deletions langchain/memory/chat_message_histories/cosmos_db_aio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Azure CosmosDB Memory History."""
import json
import logging
from types import TracebackType
from typing import Optional, Type

from azure.cosmos import PartitionKey
from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.cosmos.exceptions import CosmosHttpResponseError
from azure.identity import DefaultAzureCredential

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CosmosDBChatMessageHistoryAsync(BaseChatMessageHistory):
    """Chat history backed by Azure CosmosDB, using async.

    One Cosmos item holds one conversation: the item ``id`` is the session ID,
    the partition key (``/user_id``) is the user ID, and ``messages`` holds the
    serialized history.

    Use the instance as an async context manager so that the container is
    prepared before messages are read or written. The synchronous
    ``add_*``/``clear`` methods only touch the in-memory history; the ``a_*``
    coroutines and ``upsert_messages`` talk to Cosmos.
    """

    def __init__(
        self,
        cosmos_endpoint: str,
        cosmos_database: str,
        cosmos_container: str,
        credential: DefaultAzureCredential,
        session_id: str,
        user_id: str,
        ttl: Optional[int] = None,
    ):
        """
        Initializes a new instance of the CosmosDBChatMessageHistoryAsync class.

        :param cosmos_endpoint: The connection endpoint for the Azure Cosmos DB account.
        :param cosmos_database: The name of the database to use.
        :param cosmos_container: The name of the container to use.
        :param credential: The credential to use to authenticate to Azure Cosmos DB.
        :param session_id: The session ID to use, can be overwritten while loading.
        :param user_id: The user ID to use, can be overwritten while loading.
        :param ttl: The time to live (in seconds) to use for documents in the container.
        """
        self.cosmos_endpoint = cosmos_endpoint
        self.cosmos_database = cosmos_database
        self.cosmos_container = cosmos_container
        self.credential = credential
        self.session_id = session_id
        self.user_id = user_id
        self.ttl = ttl

        # Start with an empty history. load_messages() only assigns
        # self.messages when a stored session is found, and the synchronous
        # add_* methods append to it without any other initialisation.
        self.messages = []

        self._client = CosmosClient(
            url=self.cosmos_endpoint, credential=self.credential
        )
        # Set in __aenter__; None means "not prepared yet".
        self._container: Optional["ContainerProxy"] = None

    async def __aenter__(self) -> "CosmosDBChatMessageHistoryAsync":
        """Async context manager entry point.

        Opens the client, creates the database and the container (partitioned
        on ``/user_id``, with ``ttl`` as the default time to live) if they do
        not exist, and loads any stored messages for the current session.
        """
        await self._client.__aenter__()
        database = await self._client.create_database_if_not_exists(
            self.cosmos_database
        )
        self._container = await database.create_container_if_not_exists(
            self.cosmos_container,
            partition_key=PartitionKey("/user_id"),
            default_ttl=self.ttl,
        )
        await self.load_messages()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Async context manager exit: persist the history, then close the client."""
        try:
            await self.upsert_messages()
        finally:
            # Always release the client, even when the final upsert fails.
            await self._client.__aexit__(exc_type, exc_val, traceback)

    async def load_messages(self) -> None:
        """Retrieve the messages from Cosmos.

        Leaves the in-memory history untouched when the item cannot be read or
        holds no messages.

        :raises ValueError: If the container has not been prepared.
        """
        if not self._container:
            raise ValueError("Container not initialized")
        try:
            item = await self._container.read_item(
                item=self.session_id, partition_key=self.user_id
            )
        except CosmosHttpResponseError:
            logger.info("no session found")
            return None
        if "messages" not in item:
            return None
        stored = item["messages"]
        # upsert_messages() writes the history as a JSON-encoded string, so
        # decode it before conversion. A native list is accepted as well.
        if isinstance(stored, str):
            stored = json.loads(stored) if stored else []
        if stored:
            self.messages = messages_from_dict(stored)
        return None

    def add_user_message(self, message: str) -> None:
        """Add a user message to the memory.

        Be careful this method does not store the message externally,
        use upsert_messages (async) after this to store in Cosmos.
        Alternatively use the a_ version of this method with async.
        """
        self.messages.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Add an AI message to the memory.

        Be careful this method does not store the message externally,
        use upsert_messages (async) after this to store in Cosmos.
        Alternatively use the a_ version of this method with async.
        """
        self.messages.append(AIMessage(content=message))

    async def a_add_user_message(self, message: str) -> None:
        """Add a user message to the memory and store it in Cosmos."""
        await self.upsert_messages(HumanMessage(content=message))

    async def a_add_ai_message(self, message: str) -> None:
        """Add an AI message to the memory and store it in Cosmos."""
        await self.upsert_messages(AIMessage(content=message))

    async def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
        """Update the cosmosdb item.

        :param new_message: Optional message to append before writing the
            whole history back to Cosmos.
        :raises ValueError: If the container has not been prepared.
        """
        if new_message is not None:
            self.messages.append(new_message)
        if not self._container:
            raise ValueError("Container not initialized")
        await self._container.upsert_item(
            body={
                "id": self.session_id,
                "user_id": self.user_id,
                # Stored as a JSON string; load_messages() decodes it again.
                "messages": json.dumps(messages_to_dict(self.messages)),
            }
        )

    def clear(self) -> None:
        """Clear session memory from this memory.

        Does not delete from Cosmos, use a_clear for that."""
        self.messages = []

    async def a_clear(self) -> None:
        """Clear session memory from this memory and from Cosmos.

        :raises ValueError: If the container has not been prepared.
        """
        if not self._container:
            raise ValueError("Container not initialized")
        self.messages = []
        await self._container.delete_item(
            item=self.session_id, partition_key=self.user_id
        )
Loading