Skip to content

Commit 8d792d8

Browse files
committed
Add and use ElasticsearchChatMessageHistory
1 parent a026d0f commit 8d792d8

File tree

2 files changed

+131
-23
lines changed

2 files changed

+131
-23
lines changed

notebooks/generative-ai/chatbot.ipynb

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@
295295
"source": [
296296
"## Chat with the chatbot 💬\n",
297297
"\n",
298-
"Let's initialize our chatbot. We'll define Elasticsearch as a store for retrieving documents, OpenAI as the LLM to interpret questions and summarize answers, then we'll pass these to the conversational chain."
298+
"Let's initialize our chatbot. We'll define Elasticsearch as a store for retrieving documents and for storing the chat session history, OpenAI as the LLM to interpret questions and summarize answers, then we'll pass these to the conversational chain."
299299
]
300300
},
301301
{
@@ -307,6 +307,8 @@
307307
"from langchain.vectorstores.elastic_vector_search import ElasticKnnSearch\n",
308308
"from langchain.llms import OpenAI\n",
309309
"from langchain.chains import ConversationalRetrievalChain\n",
310+
"from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory\n",
311+
"from uuid import uuid4\n",
310312
"\n",
311313
"store = ElasticKnnSearch(\n",
312314
" es_connection=elasticsearch_client,\n",
@@ -322,6 +324,13 @@
322324
" llm=llm,\n",
323325
" retriever=retriever,\n",
324326
" return_source_documents=True\n",
327+
")\n",
328+
"\n",
329+
"session_id = str(uuid4())\n",
330+
"chat_history = ElasticsearchChatMessageHistory(\n",
331+
" client=elasticsearch_client,\n",
332+
" session_id=session_id,\n",
333+
" index='workplace-docs-chat-history'\n",
325334
")"
326335
]
327336
},
@@ -343,33 +352,34 @@
343352
"name": "stdout",
344353
"output_type": "stream",
345354
"text": [
346-
"QUESTION: What does NASA stand for? \n",
347-
"ANSWER: NASA stands for North America South America. \n",
348-
"SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Code Of Conduct', 'Code Of Conduct', 'Swe Career Matrix']\n",
349-
"QUESTION: Which countries are part of it? \n",
350-
"ANSWER: The North America South America region includes the United States, Canada, Mexico, as well as Central and South America. \n",
351-
"SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Sales Organization Overview', 'Sales Organization Overview', 'Fy2024 Company Sales Strategy']\n",
352-
"QUESTION: Who are the team's leads? \n",
353-
"ANSWER: Laura Martinez is the Area Vice-President of North America, and Gary Johnson is the Area Vice-President of South America. \n",
354-
"SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Sales Organization Overview', 'Swe Career Matrix', 'Swe Career Matrix']\n"
355+
"[CHAT SESSION ID] 09116274-f852-4ae6-9617-c5aa2a17bbff\n",
356+
"[QUESTION] What does NASA stand for?\n",
357+
"[ANSWER] NASA stands for North America South America region.\n",
358+
" [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Code Of Conduct', 'Code Of Conduct', 'Swe Career Matrix']\n",
359+
"[QUESTION] Which countries are part of it?\n",
360+
"[ANSWER] The North America South America region includes the United States, Canada, Mexico, as well as Central and South America.\n",
361+
" [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Sales Organization Overview', 'Sales Organization Overview', 'Wfh Policy Update May 2023']\n",
362+
"[QUESTION] Who are the team's leads?\n",
363+
"[ANSWER] Laura Martinez is the Area Vice-President of North America, and Gary Johnson is the Area Vice-President of South America.\n",
364+
" [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Swe Career Matrix', 'Sales Organization Overview', 'Swe Career Matrix']\n"
355365
]
356366
}
357367
],
358368
"source": [
359369
"# Define a convenience function for Q&A\n",
360-
"def ask(question, history):\n",
361-
" result = chat({\"question\": question, \"chat_history\": chat_history})\n",
362-
" print(\"QUESTION: \", question,\n",
363-
" \"\\nANSWER: \", result[\"answer\"],\n",
364-
" \"\\nSUPPORTING DOCUMENTS: \", list(map(lambda d: d.metadata[\"name\"], list(result[\"source_documents\"])))\n",
365-
" )\n",
366-
" history.append((question, result[\"answer\"]))\n",
367-
" \n",
368-
"chat_history = []\n",
369-
"\n",
370+
"def ask(question, chat_history):\n",
371+
" result = chat({\"question\": question, \"chat_history\": chat_history.messages})\n",
372+
" print(f\"\"\"[QUESTION] {question}\n",
373+
"[ANSWER] {result[\"answer\"]}\n",
374+
" [SUPPORTING DOCUMENTS] {list(map(lambda d: d.metadata[\"name\"], list(result[\"source_documents\"])))}\"\"\")\n",
375+
" chat_history.add_user_message(result[\"question\"])\n",
376+
" chat_history.add_ai_message(result[\"answer\"])\n",
377+
"\n",
378+
"# Chat away!\n",
379+
"print(f\"[CHAT SESSION ID] {session_id}\")\n",
370380
"ask(\"What does NASA stand for?\", chat_history)\n",
371381
"ask(\"Which countries are part of it?\", chat_history)\n",
372-
"ask(\"Who are the team's leads?\", chat_history)\n"
382+
"ask(\"Who are the team's leads?\", chat_history)"
373383
]
374384
},
375385
{
@@ -385,7 +395,23 @@
385395
"source": [
386396
"# (Optional) Clean up 🧹\n",
387397
"\n",
388-
"Once we're done, we can delete the Elasticsearch index."
398+
"Once we're done, we can clean up the chat history for this session..."
399+
]
400+
},
401+
{
402+
"cell_type": "code",
403+
"execution_count": null,
404+
"metadata": {},
405+
"outputs": [],
406+
"source": [
407+
"chat_history.clear()"
408+
]
409+
},
410+
{
411+
"cell_type": "markdown",
412+
"metadata": {},
413+
"source": [
414+
"... or delete the indices."
389415
]
390416
},
391417
{
@@ -394,7 +420,8 @@
394420
"metadata": {},
395421
"outputs": [],
396422
"source": [
397-
"elasticsearch_client.indices.delete(index='workplace-docs')"
423+
"elasticsearch_client.indices.delete(index='workplace-docs')\n",
424+
"elasticsearch_client.indices.delete(index='workplace-docs-chat-history')"
398425
]
399426
}
400427
],
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import json
2+
import logging
3+
from typing import List
4+
from elasticsearch import ApiError, Elasticsearch
5+
6+
from langchain.schema import BaseChatMessageHistory
7+
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
8+
9+
logger = logging.getLogger(__name__)
10+
11+
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
12+
"""Chat message history that stores history in Elasticsearch.
13+
14+
Args:
15+
client: Elasticsearch client.
16+
index: name of the index to use.
17+
session_id: arbitrary key that is used to store the messages
18+
of a single chat session.
19+
"""
20+
21+
def __init__(
22+
self,
23+
client: Elasticsearch,
24+
index: str,
25+
session_id: str,
26+
):
27+
self.client: Elasticsearch = client
28+
self.index: str = index
29+
self.session_id: str = session_id
30+
31+
if not client.indices.exists(index=index):
32+
client.indices.create(
33+
index=index,
34+
mappings={
35+
"properties": {
36+
"session_id": {"type": "keyword"},
37+
"history": {"type": "text"}
38+
}
39+
}
40+
)
41+
42+
@property
43+
def messages(self) -> List[BaseMessage]:
44+
"""Retrieve the messages from Elasticsearch"""
45+
try:
46+
result = self.client.search(
47+
index=self.index,
48+
query={"term": {"session_id": self.session_id}}
49+
)
50+
except ApiError as err:
51+
logger.error(err)
52+
53+
if result and len(result["hits"]["hits"]) > 0:
54+
items = [json.loads(document["_source"]["history"]) for document in result["hits"]["hits"]]
55+
else:
56+
items = []
57+
58+
return messages_from_dict(items)
59+
60+
def add_message(self, message: BaseMessage) -> None:
61+
"""Add a message to the chat session in Elasticsearch"""
62+
try:
63+
self.client.index(
64+
index=self.index,
65+
body={
66+
"session_id": self.session_id,
67+
"history": json.dumps(_message_to_dict(message))
68+
}
69+
)
70+
except ApiError as err:
71+
logger.error(err)
72+
73+
def clear(self) -> None:
74+
"""Clear session memory in Elasticsearch"""
75+
try:
76+
self.client.delete_by_query(
77+
index=self.index,
78+
query={"term": {"session_id": self.session_id}}
79+
)
80+
except ApiError as err:
81+
logger.error(err)

0 commit comments

Comments
 (0)