|
295 | 295 | "source": [ |
296 | 296 | "## Chat with the chatbot 💬\n", |
297 | 297 | "\n", |
298 | | - "Let's initialize our chatbot. We'll define Elasticsearch as a store for retrieving documents, OpenAI as the LLM to interpret questions and summarize answers, then we'll pass these to the conversational chain." |
| 298 | + "Let's initialize our chatbot. We'll use Elasticsearch both as the store for retrieving documents and as the store for the chat session history, and OpenAI as the LLM that interprets questions and summarizes answers. We'll then pass the retriever and the LLM to the conversational chain." |
299 | 299 | ] |
300 | 300 | }, |
301 | 301 | { |
|
307 | 307 | "from langchain.vectorstores.elastic_vector_search import ElasticKnnSearch\n", |
308 | 308 | "from langchain.llms import OpenAI\n", |
309 | 309 | "from langchain.chains import ConversationalRetrievalChain\n", |
| 310 | + "from lib.elasticsearch_chat_message_history import ElasticsearchChatMessageHistory\n", |
| 311 | + "from uuid import uuid4\n", |
310 | 312 | "\n", |
311 | 313 | "store = ElasticKnnSearch(\n", |
312 | 314 | " es_connection=elasticsearch_client,\n", |
|
322 | 324 | " llm=llm,\n", |
323 | 325 | " retriever=retriever,\n", |
324 | 326 | " return_source_documents=True\n", |
| 327 | + ")\n", |
| 328 | + "\n", |
| 329 | + "session_id = str(uuid4())\n", |
| 330 | + "chat_history = ElasticsearchChatMessageHistory(\n", |
| 331 | + " client=elasticsearch_client,\n", |
| 332 | + " session_id=session_id,\n", |
| 333 | + " index='workplace-docs-chat-history'\n", |
325 | 334 | ")" |
326 | 335 | ] |
327 | 336 | }, |
|
343 | 352 | "name": "stdout", |
344 | 353 | "output_type": "stream", |
345 | 354 | "text": [ |
346 | | - "QUESTION: What does NASA stand for? \n", |
347 | | - "ANSWER: NASA stands for North America South America. \n", |
348 | | - "SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Code Of Conduct', 'Code Of Conduct', 'Swe Career Matrix']\n", |
349 | | - "QUESTION: Which countries are part of it? \n", |
350 | | - "ANSWER: The North America South America region includes the United States, Canada, Mexico, as well as Central and South America. \n", |
351 | | - "SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Sales Organization Overview', 'Sales Organization Overview', 'Fy2024 Company Sales Strategy']\n", |
352 | | - "QUESTION: Who are the team's leads? \n", |
353 | | - "ANSWER: Laura Martinez is the Area Vice-President of North America, and Gary Johnson is the Area Vice-President of South America. \n", |
354 | | - "SUPPORTING DOCUMENTS: ['Sales Organization Overview', 'Sales Organization Overview', 'Swe Career Matrix', 'Swe Career Matrix']\n" |
| 355 | + "[CHAT SESSION ID] 09116274-f852-4ae6-9617-c5aa2a17bbff\n", |
| 356 | + "[QUESTION] What does NASA stand for?\n", |
| 357 | + "[ANSWER] NASA stands for North America South America region.\n", |
| 358 | + " [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Code Of Conduct', 'Code Of Conduct', 'Swe Career Matrix']\n", |
| 359 | + "[QUESTION] Which countries are part of it?\n", |
| 360 | + "[ANSWER] The North America South America region includes the United States, Canada, Mexico, as well as Central and South America.\n", |
| 361 | + " [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Sales Organization Overview', 'Sales Organization Overview', 'Wfh Policy Update May 2023']\n", |
| 362 | + "[QUESTION] Who are the team's leads?\n", |
| 363 | + "[ANSWER] Laura Martinez is the Area Vice-President of North America, and Gary Johnson is the Area Vice-President of South America.\n", |
| 364 | + " [SUPPORTING DOCUMENTS] ['Sales Organization Overview', 'Swe Career Matrix', 'Sales Organization Overview', 'Swe Career Matrix']\n" |
355 | 365 | ] |
356 | 366 | } |
357 | 367 | ], |
358 | 368 | "source": [ |
359 | 369 | "# Define a convenience function for Q&A\n", |
360 | | - "def ask(question, history):\n", |
361 | | - " result = chat({\"question\": question, \"chat_history\": chat_history})\n", |
362 | | - " print(\"QUESTION: \", question,\n", |
363 | | - " \"\\nANSWER: \", result[\"answer\"],\n", |
364 | | - " \"\\nSUPPORTING DOCUMENTS: \", list(map(lambda d: d.metadata[\"name\"], list(result[\"source_documents\"])))\n", |
365 | | - " )\n", |
366 | | - " history.append((question, result[\"answer\"]))\n", |
367 | | - " \n", |
368 | | - "chat_history = []\n", |
369 | | - "\n", |
| 370 | + "def ask(question, chat_history):\n", |
| 371 | + " result = chat({\"question\": question, \"chat_history\": chat_history.messages})\n", |
| 372 | + " print(f\"\"\"[QUESTION] {question}\n", |
| 373 | + "[ANSWER] {result[\"answer\"]}\n", |
| 374 | + " [SUPPORTING DOCUMENTS] {list(map(lambda d: d.metadata[\"name\"], list(result[\"source_documents\"])))}\"\"\")\n", |
| 375 | + " chat_history.add_user_message(result[\"question\"])\n", |
| 376 | + " chat_history.add_ai_message(result[\"answer\"])\n", |
| 377 | + "\n", |
| 378 | + "# Chat away!\n", |
| 379 | + "print(f\"[CHAT SESSION ID] {session_id}\")\n", |
370 | 380 | "ask(\"What does NASA stand for?\", chat_history)\n", |
371 | 381 | "ask(\"Which countries are part of it?\", chat_history)\n", |
372 | | - "ask(\"Who are the team's leads?\", chat_history)\n" |
| 382 | + "ask(\"Who are the team's leads?\", chat_history)" |
373 | 383 | ] |
374 | 384 | }, |
375 | 385 | { |
|
385 | 395 | "source": [ |
386 | 396 | "# (Optional) Clean up 🧹\n", |
387 | 397 | "\n", |
388 | | - "Once we're done, we can delete the Elasticsearch index." |
| 398 | + "Once we're done, we can clean up the chat history for this session..." |
| 399 | + ] |
| 400 | + }, |
| 401 | + { |
| 402 | + "cell_type": "code", |
| 403 | + "execution_count": null, |
| 404 | + "metadata": {}, |
| 405 | + "outputs": [], |
| 406 | + "source": [ |
| 407 | + "chat_history.clear()" |
| 408 | + ] |
| 409 | + }, |
| 410 | + { |
| 411 | + "cell_type": "markdown", |
| 412 | + "metadata": {}, |
| 413 | + "source": [ |
| 414 | + "... or delete the indices." |
389 | 415 | ] |
390 | 416 | }, |
391 | 417 | { |
|
394 | 420 | "metadata": {}, |
395 | 421 | "outputs": [], |
396 | 422 | "source": [ |
397 | | - "elasticsearch_client.indices.delete(index='workplace-docs')" |
| 423 | + "elasticsearch_client.indices.delete(index='workplace-docs')\n", |
| 424 | + "elasticsearch_client.indices.delete(index='workplace-docs-chat-history')" |
398 | 425 | ] |
399 | 426 | } |
400 | 427 | ], |
|
0 commit comments