Skip to content

Commit 14d0882

Browse files
committed
Add chat session support and error handling
1 parent d4a0001 commit 14d0882

File tree

1 file changed

+16
-10
lines changed
  • example-apps/workplace-search/api

1 file changed

+16
-10
lines changed

example-apps/workplace-search/api/app.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from elasticsearch import Elasticsearch
22
from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory
3-
from flask import Flask, request, Response
3+
from flask import Flask, jsonify, request, Response
44
from langchain.callbacks.base import BaseCallbackHandler
55
from langchain.chains import ConversationalRetrievalChain
66
from langchain.chat_models import ChatOpenAI
@@ -20,6 +20,7 @@
2020
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
2121

2222
POISON_MESSAGE = "~~~END~~~"
23+
SESSION_ID_TAG = "[SESSION_ID]"
2324
SOURCE_TAG = "[SOURCE]"
2425
DONE_TAG = "[DONE]"
2526

@@ -90,15 +91,7 @@ def on_llm_end(self, response, *, run_id, parent_run_id = None, **kwargs):
9091
retriever=store.as_retriever(),
9192
return_source_documents=True,
9293
combine_docs_chain_kwargs={'prompt': qa_prompt},
93-
verbose=True
94-
)
95-
96-
session_id = str(uuid4())
97-
print('Starting chat with session ID: ', session_id)
98-
chat_history = ElasticsearchChatMessageHistory(
99-
client=elasticsearch_client,
100-
index=INDEX_CHAT_HISTORY,
101-
session_id=session_id
94+
# verbose=True
10295
)
10396

10497
stream_queue = Queue()
@@ -125,8 +118,21 @@ def ask_question(question, queue, chat_history):
125118
def api_chat():
126119
request_json = request.get_json()
127120
question = request_json.get("question")
121+
if question is None:
122+
return jsonify({"msg": "Missing question from request JSON"}), 400
123+
124+
session_id = request.args.get('session_id', str(uuid4()))
125+
126+
print('Chat session ID: ', session_id)
127+
chat_history = ElasticsearchChatMessageHistory(
128+
client=elasticsearch_client,
129+
index=INDEX_CHAT_HISTORY,
130+
session_id=session_id
131+
)
128132

129133
def generate(queue: Queue):
134+
yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
135+
130136
message = None
131137
while True:
132138
message = queue.get()

0 commit comments

Comments
 (0)