diff --git a/search4all.py b/search4all.py index 3d0bf0c..26bdc0b 100644 --- a/search4all.py +++ b/search4all.py @@ -44,6 +44,9 @@ # does not respond within this time, we will return an error. DEFAULT_SEARCH_ENGINE_TIMEOUT = 5 +# 默认记录的对话历史长度 +MAX_HISTORY_LEN = 10 + # If the user did not provide a query, we will use this default query. _default_query = "Who said 'live long and prosper'?" @@ -104,6 +107,33 @@ def get(self, key: str): def put(self, key: str, value: str): self._db[key] = value self._db.commit() + + def append(self, key: str, value): + """ 记录聊天历史 """ + self._db[key] = self._db.get(key, []) + # 最长记录的对话轮数 MAX_HISTORY_LEN + _ = self._db[key][-MAX_HISTORY_LEN:] + _.append(value) + self._db[key] = _ + self._db.commit() + +# 格式化输出部分 +def extract_all_sections(text: str): + # 定义正则表达式模式以匹配各部分 + sections_pattern = r"(.*?)__LLM_RESPONSE__(.*?)(__RELATED_QUESTIONS__(.*))?$" + + # 使用正则表达式查找各部分内容 + match = re.search(sections_pattern, text, re.DOTALL) + + # 从匹配结果中提取文本,如果没有匹配则返回None + if match: + search_results = match.group(1).strip() # 前置文本作为搜索结果 + llm_response = match.group(2).strip() # 问题回答部分 + related_questions = match.group(4).strip() if match.group(4) else "" # 相关问题文本,如果不存在则返回空字符串 + else: + search_results, llm_response, related_questions = None, None, None + + return search_results, llm_response, related_questions def search_with_search1api(query: str, search1api_key: str): """ @@ -417,6 +447,10 @@ async def server_init(_app, loop): _app.ctx.should_do_related_questions = bool( os.getenv("RELATED_QUESTIONS") in ("1", "yes", "true") ) + # 是否开始聊天历史的环境变量 + _app.ctx.should_do_chat_history = bool( + os.getenv("CHAT_HISTORY") in ("1", "yes", "true") + ) # Create httpx Session _app.ctx.http_session = httpx.AsyncClient( timeout=httpx.Timeout(connect=10, read=120, write=120, pool=10), @@ -588,21 +622,72 @@ async def query_function(request: sanic.Request): generate_related_questions = params.get("generate_related_questions", True) if not query: raise HTTPException("query must be provided.") + + # 定义传递给生成答案的聊天历史 以及搜索结果 + chat_history = [] + contexts = "" + # Note that, if uuid exists, we don't check if the stored query is the same # as the current query, and simply return the stored result. This is to enable # the user to share a searched link to others and have others see the same result. if search_uuid: - try: - result = await _app.loop.run_in_executor( - _app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid - ) - return sanic.text(result) - except KeyError: - logger.info(f"Key {search_uuid} not found, will generate again.") - except Exception as e: - logger.error( - f"KV error: {e}\n{traceback.format_exc()}, will generate again." - ) + if _app.ctx.should_do_chat_history: + # 开启了历史记录,读取历史记录 + history = [] + try: + history = await _app.loop.run_in_executor( + _app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), f"{search_uuid}_history" + ) + result = await _app.loop.run_in_executor( + _app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid + ) + # return sanic.text(result) + except KeyError: + logger.info(f"Key {search_uuid} not found, will generate again.") + except Exception as e: + logger.error( + f"KV error: {e}\n{traceback.format_exc()}, will generate again." + ) + # 如果存在历史记录 + if history: + # 获取最后一次记录 + last_entry = history[-1] + # 确定最后一次记录的数据完整性 + old_query, search_results, llm_response = last_entry.get("query", ""), last_entry.get("search_results", ""), last_entry.get("llm_response", "") + # 如果存在旧查询和搜索结果 + if old_query and search_results: + if old_query != query: + # 从历史记录中获取搜索结果(最后一条) + contexts = history[-1]["search_results"] + # 将历史聊天的提问和回答提取 + chat_history = [] + for entry in history: + if "query" in entry and "llm_response" in entry: + chat_history.append({"role": "user", "content": entry["query"]}) + chat_history.append({"role": "assistant", "content": entry["llm_response"]}) + else: + return sanic.text(result["txt"]) # 查询未改变,直接返回结果 + else: + try: + result = await _app.loop.run_in_executor( + _app.ctx.executor, lambda sid: _app.ctx.kv.get(sid), search_uuid + ) + # debug + if isinstance(result, dict): + # 只有相同的查询才返回同一个结果, 兼容多轮对话。 + if result["query"] == query: + return sanic.text(result["txt"]) + else: + # TODO: 兼容旧数据代码 之后删除 + # 旧数据强制刷新 + # return sanic.text(result) + pass + except KeyError: + logger.info(f"Key {search_uuid} not found, will generate again.") + except Exception as e: + logger.error( + f"KV error: {e}\n{traceback.format_exc()}, will generate again." + ) else: raise HTTPException("search_uuid must be provided.") @@ -619,9 +704,11 @@ async def query_function(request: sanic.Request): # query = query or _default_query # Basic attack protection: remove "[INST]" or "[/INST]" from the query query = re.sub(r"\[/?INST\]", "", query) - contexts = await _app.loop.run_in_executor( - _app.ctx.executor, _app.ctx.search_function, query - ) + # 开启聊天历史并且有有效数据 则不再重新请求搜索 + if not _app.ctx.should_do_chat_history or contexts in ("", None): + contexts = await _app.loop.run_in_executor( + _app.ctx.executor, _app.ctx.search_function, query + ) system_prompt = _rag_query_text.format( context="\n\n".join( @@ -630,6 +717,13 @@ async def query_function(request: sanic.Request): ) try: openai_client = new_async_client(_app) + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": query}, + ] + if chat_history and len(chat_history) % 2 == 0: + # 将历史插入到消息中 index = 1 的位置 + messages[1:1] = chat_history llm_response = await openai_client.chat.completions.create( model=_app.ctx.model, messages=[ @@ -664,8 +758,24 @@ async def query_function(request: sanic.Request): # Second, upload to KV. Note that if uploading to KV fails, we will silently # ignore it, because we don't want to affect the user experience. await response.eof() + if _app.ctx.should_do_chat_history: + # 保存聊天历史 + _search_results, _llm_response, _related_questions = await _app.loop.run_in_executor( + _app.ctx.executor, extract_all_sections, "".join(all_yielded_results) + ) + if _search_results: + _search_results = json.loads(_search_results) + if _related_questions: + _related_questions = json.loads(_related_questions) + _ = _app.ctx.executor.submit( + _app.ctx.kv.append, f"{search_uuid}_history", { + "query": query, + "search_results": _search_results, + "llm_response": _llm_response, + "related_questions": _related_questions + }) _ = _app.ctx.executor.submit( - _app.ctx.kv.put, search_uuid, "".join(all_yielded_results) + _app.ctx.kv.put, search_uuid, {"query": query, "txt": "".join(all_yielded_results)} # 原来的缓存是直接根据sid返回结果,开启聊天历史后 同一个sid存储多轮对话,因此需要存储 query 兼容多轮对话 ) diff --git a/web/src/app/components/search.tsx b/web/src/app/components/search.tsx index cde82ee..d95bbef 100644 --- a/web/src/app/components/search.tsx +++ b/web/src/app/components/search.tsx @@ -1,20 +1,34 @@ "use client"; import { getSearchUrl } from "@/app/utils/get-search-url"; -import { ArrowRight } from "lucide-react"; +import { ArrowRight, ArrowUp } from "lucide-react"; import { nanoid } from "nanoid"; import { useRouter } from "next/navigation"; import React, { FC, useState } from "react"; +import { useSearchParams } from "next/navigation"; -export const Search: FC = () => { +interface SearchProps { + useContinueButton?: boolean; // true: 使用“继续对话”按钮; false: 使用“新的搜索”按钮 +} +export const Search: FC = ({ useContinueButton = false }) => { const [value, setValue] = useState(""); const router = useRouter(); + const searchParams = useSearchParams(); + const old_rid = decodeURIComponent(searchParams.get("rid") || ""); + const handleNewSearch = () => { + // 可以在这里重置任何需要的状态,以准备一个新的搜索 + if (value) { + setValue(""); // 清空搜索框 + router.push(getSearchUrl(encodeURIComponent(value), nanoid())); + } + }; return (
{ e.preventDefault(); if (value) { setValue(""); - router.push(getSearchUrl(encodeURIComponent(value), nanoid())); + const rid = useContinueButton ? old_rid : nanoid(); + router.push(getSearchUrl(encodeURIComponent(value), rid)); } }} > @@ -34,7 +48,7 @@ export const Search: FC = () => { type="submit" className="w-auto py-1 px-2 bg-black border-black text-white fill-white active:scale-95 border overflow-hidden relative rounded-xl" > - + {useContinueButton ? : } diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 48800a8..c957af3 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -18,7 +18,7 @@ export default function SearchPage() {
- +