-
Notifications
You must be signed in to change notification settings - Fork 4.1k
/
chatreadretrieveread.py
246 lines (222 loc) · 10.7 KB
/
chatreadretrieveread.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
from typing import Any, Coroutine, List, Literal, Optional, Union, overload
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import VectorQuery
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai_messages_token_helper import build_messages, get_token_limit
from approaches.approach import ThoughtStep
from approaches.chatapproach import ChatApproach
from core.authentication import AuthenticationHelper
class ChatReadRetrieveReadApproach(ChatApproach):
    """
    A multi-step approach that first uses OpenAI to turn the user's question into a search query,
    then uses Azure AI Search to retrieve relevant documents, and then sends the conversation history,
    original user question, and search results to OpenAI to generate a response.
    """

    def __init__(
        self,
        *,
        search_client: SearchClient,
        auth_helper: AuthenticationHelper,
        openai_client: AsyncOpenAI,
        chatgpt_model: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_model: str,
        embedding_dimensions: int,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
    ) -> None:
        """Store the clients and settings used by the approach.

        All arguments are keyword-only and are saved unchanged as attributes of the
        same name. In addition, ``chatgpt_token_limit`` is derived from
        ``chatgpt_model`` via ``get_token_limit``.
        """
        self.search_client = search_client
        self.openai_client = openai_client
        self.auth_helper = auth_helper
        self.chatgpt_model = chatgpt_model
        self.chatgpt_deployment = chatgpt_deployment
        self.embedding_deployment = embedding_deployment
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        # ALLOW_NON_GPT_MODELS is not defined in this class; presumably inherited from a base class.
        self.chatgpt_token_limit = get_token_limit(chatgpt_model, default_to_minimum=self.ALLOW_NON_GPT_MODELS)

    @property
    def system_message_chat_conversation(self) -> str:
        """Return the system prompt template for the answer-generation step.

        The template contains the ``{follow_up_questions_prompt}`` and
        ``{injected_prompt}`` placeholders; it is not formatted here.
        """
        # NOTE: every line of the literal below is part of the runtime string, so its
        # continuation lines are intentionally left exactly as they appear (no added indent).
        return """Assistant helps the company employees with their healthcare plan questions, and questions about the employee handbook. Be brief in your answers.
Answer ONLY with the facts listed in the list of sources below. If there isn't enough information below, say you don't know. Do not generate answers that don't use the sources below. If asking a clarifying question to the user would help, ask the question.
If the question is not in English, answer in the language used in the question.
Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. Use square brackets to reference the source, for example [info1.txt]. Don't combine sources, list each source separately, for example [info1.txt][info2.pdf].
{follow_up_questions_prompt}
{injected_prompt}
"""

    @overload
    async def run_until_final_call(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: Literal[False],
    ) -> tuple[dict[str, Any], Coroutine[Any, Any, ChatCompletion]]: ...

    @overload
    async def run_until_final_call(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: Literal[True],
    ) -> tuple[dict[str, Any], Coroutine[Any, Any, AsyncStream[ChatCompletionChunk]]]: ...

    async def run_until_final_call(
        self,
        messages: list[ChatCompletionMessageParam],
        overrides: dict[str, Any],
        auth_claims: dict[str, Any],
        should_stream: bool = False,
    ) -> tuple[dict[str, Any], Coroutine[Any, Any, Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]]]:
        """Run query generation and retrieval, then prepare (but do not await) the final chat call.

        Args:
            messages: Conversation so far; the last entry is treated as the user's
                current question and its ``content`` must be a string.
            overrides: Per-request settings. Keys read here: ``seed``,
                ``retrieval_mode``, ``semantic_ranker``, ``semantic_captions``,
                ``top``, ``minimum_search_score``, ``minimum_reranker_score``,
                ``prompt_template``, ``suggest_followup_questions``, ``temperature``.
                It is also passed to ``self.build_filter``.
            auth_claims: Passed to ``self.build_filter`` together with ``overrides``.
            should_stream: Forwarded as ``stream`` to the final chat completion call.

        Returns:
            A tuple ``(extra_info, chat_coroutine)``. ``extra_info`` holds
            ``"data_points"`` (the retrieved source texts) and ``"thoughts"`` (a list
            of ``ThoughtStep`` entries describing each stage). ``chat_coroutine`` is
            the un-awaited final chat completion request; the caller must await it.

        Raises:
            ValueError: If ``messages`` is empty, or the most recent message's
                content is not a string.
        """
        seed = overrides.get("seed", None)
        retrieval_mode = overrides.get("retrieval_mode")
        # A missing retrieval_mode (None) enables both text and vector search.
        use_text_search = retrieval_mode in ["text", "hybrid", None]
        use_vector_search = retrieval_mode in ["vectors", "hybrid", None]
        use_semantic_ranker = bool(overrides.get("semantic_ranker"))
        use_semantic_captions = bool(overrides.get("semantic_captions"))
        top = overrides.get("top", 3)
        minimum_search_score = overrides.get("minimum_search_score", 0.0)
        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
        # Named search_filter (not `filter`) so the builtin is not shadowed.
        search_filter = self.build_filter(overrides, auth_claims)

        if not messages:
            # Fail with a descriptive error instead of an IndexError on messages[-1].
            raise ValueError("At least one message is required to generate a response.")
        original_user_query = messages[-1]["content"]
        if not isinstance(original_user_query, str):
            raise ValueError("The most recent message content must be a string.")
        user_query_request = "Generate search query for: " + original_user_query

        # Azure OpenAI takes the deployment name as the model name
        model_or_deployment = self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model

        def model_details() -> dict[str, Any]:
            # Build a fresh dict on every call so thought steps never share a mutable object.
            if self.chatgpt_deployment:
                return {"model": self.chatgpt_model, "deployment": self.chatgpt_deployment}
            return {"model": self.chatgpt_model}

        tools: List[ChatCompletionToolParam] = [
            {
                "type": "function",
                "function": {
                    "name": "search_sources",
                    "description": "Retrieve sources from the Azure AI Search index",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "search_query": {
                                "type": "string",
                                "description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
                            }
                        },
                        "required": ["search_query"],
                    },
                },
            }
        ]

        # STEP 1: Generate an optimized keyword search query based on the chat history and the last question
        query_response_token_limit = 100
        query_messages = build_messages(
            model=self.chatgpt_model,
            system_prompt=self.query_prompt_template,
            tools=tools,
            few_shots=self.query_prompt_few_shots,
            past_messages=messages[:-1],
            new_user_content=user_query_request,
            # Reserve room in the context window for the model's reply.
            max_tokens=self.chatgpt_token_limit - query_response_token_limit,
            fallback_to_default=self.ALLOW_NON_GPT_MODELS,
        )

        chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
            messages=query_messages,  # type: ignore
            model=model_or_deployment,
            temperature=0.0,  # Minimize creativity for search query generation
            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, setting too high may affect performance
            n=1,
            tools=tools,
            seed=seed,
        )

        query_text = self.get_search_query(chat_completion, original_user_query)

        # STEP 2: Retrieve relevant documents from the search index with the GPT optimized query

        # If retrieval mode includes vectors, compute an embedding for the query
        vectors: list[VectorQuery] = []
        if use_vector_search:
            vectors.append(await self.compute_text_embedding(query_text))

        results = await self.search(
            top,
            query_text,
            search_filter,
            vectors,
            use_text_search,
            use_vector_search,
            use_semantic_ranker,
            use_semantic_captions,
            minimum_search_score,
            minimum_reranker_score,
        )

        sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)
        content = "\n".join(sources_content)

        # STEP 3: Generate a contextual and content specific answer using the search results and chat history

        # Allow client to replace the entire prompt, or to inject into the existing prompt using >>>
        system_message = self.get_system_prompt(
            overrides.get("prompt_template"),
            self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else "",
        )

        response_token_limit = 1024
        answer_messages = build_messages(
            model=self.chatgpt_model,
            system_prompt=system_message,
            past_messages=messages[:-1],
            # Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
            new_user_content=original_user_query + "\n\nSources:\n" + content,
            max_tokens=self.chatgpt_token_limit - response_token_limit,
            fallback_to_default=self.ALLOW_NON_GPT_MODELS,
        )

        data_points = {"text": sources_content}

        extra_info = {
            "data_points": data_points,
            "thoughts": [
                ThoughtStep(
                    "Prompt to generate search query",
                    query_messages,
                    model_details(),
                ),
                ThoughtStep(
                    "Search using generated search query",
                    query_text,
                    {
                        "use_semantic_captions": use_semantic_captions,
                        "use_semantic_ranker": use_semantic_ranker,
                        "top": top,
                        "filter": search_filter,
                        "use_vector_search": use_vector_search,
                        "use_text_search": use_text_search,
                    },
                ),
                ThoughtStep(
                    "Search results",
                    [result.serialize_for_results() for result in results],
                ),
                ThoughtStep(
                    "Prompt to generate answer",
                    answer_messages,
                    model_details(),
                ),
            ],
        }

        # Intentionally not awaited: the caller decides when to run the final request.
        chat_coroutine = self.openai_client.chat.completions.create(
            model=model_or_deployment,
            messages=answer_messages,
            temperature=overrides.get("temperature", 0.3),
            max_tokens=response_token_limit,
            n=1,
            stream=should_stream,
            seed=seed,
        )
        return (extra_info, chat_coroutine)