From 95f809187e84aacd7524455587e211829f101190 Mon Sep 17 00:00:00 2001
From: KevinHuSh
Date: Thu, 16 May 2024 20:14:53 +0800
Subject: [PATCH] add stream chat (#811)

### What problem does this PR solve?

Adds streaming (server-sent events) output to the two completion endpoints, `chat_streamly()` implementations for every chat-model provider, and a new `/system` app with `version` and `status` health endpoints. #709

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
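Each streamed event is a `data:` line carrying the usual `{retcode, retmsg, data}` envelope followed by a blank line, and the stream closes with an event whose `data` is `true`. A minimal client sketch (host, route prefix and key are placeholders for your deployment):

```python
# Minimal sketch of a client for the streaming completion endpoint.
# The base URL, route prefix and API key below are assumptions, not part of this patch.
import json
import requests

resp = requests.post(
    "http://localhost:9380/v1/api/completion",          # assumed mount point
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"conversation_id": "<conversation-id>",
          "messages": [{"role": "user", "content": "hi"}],
          "stream": True},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue                                        # blank separator between events
    event = json.loads(raw.decode("utf-8")[len("data:"):])
    if event["data"] is True:                           # final sentinel event
        break
    print(event["data"]["answer"])                      # whole answer so far, not a delta
```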
---
 api/apps/api_app.py                 |  89 +++++++++++------------
 api/apps/conversation_app.py        |  50 +++++++++++--
 api/apps/system_app.py              |  67 +++++++++++++++++
 api/db/services/dialog_service.py   |  83 ++++++++++++---------
 api/db/services/document_service.py |   4 +-
 api/db/services/llm_service.py      |  12 +++-
 api/utils/api_utils.py              |   4 --
 rag/llm/chat_model.py               | 108 +++++++++++++++++++++++++++-
 rag/svr/task_executor.py            |  25 ++++---
 rag/utils/es_conn.py                |   3 +
 rag/utils/minio_conn.py             |  10 +++
 rag/utils/redis_conn.py             |   4 ++
 12 files changed, 355 insertions(+), 104 deletions(-)
 create mode 100644 api/apps/system_app.py

diff --git a/api/apps/api_app.py b/api/apps/api_app.py
index ca5fb8d5db..9f80996d99 100644
--- a/api/apps/api_app.py
+++ b/api/apps/api_app.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import os
 import re
 from datetime import datetime, timedelta
-from flask import request
+from flask import request, Response
 from flask_login import login_required, current_user

 from api.db import FileType, ParserType
@@ -31,11 +32,11 @@
 from api.utils import get_uuid, current_timestamp, datetime_format
 from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request
 from itsdangerous import URLSafeTimedSerializer
-from api.db.services.task_service import TaskService, queue_tasks
+
 from api.utils.file_utils import filename_type, thumbnail
 from rag.utils.minio_conn import MINIO
-from api.db.db_models import Task
-from api.db.services.file2document_service import File2DocumentService
+
+
 def generate_confirmation_token(tenent_id):
     serializer = URLSafeTimedSerializer(tenent_id)
     return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
@@ -164,6 +165,7 @@ def completion():
         e, conv = API4ConversationService.get_by_id(req["conversation_id"])
         if not e:
             return get_data_error_result(retmsg="Conversation not found!")
+        if "quote" not in req: req["quote"] = False
         msg = []
         for m in req["messages"]:
@@ -180,13 +182,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
+
         if not conv.reference: conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        API4ConversationService.append_message(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else:
+                conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
+                                               ensure_ascii=False) + "\n\n"
+                API4ConversationService.append_message(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+        else:
+            answer = None
+            for ans in chat(dia, msg, False, **req):
+                answer = ans
+            fillin_conv(answer)
+            API4ConversationService.append_message(conv.id, conv.to_dict())
+            return get_json_result(data=answer)
+
     except Exception as e:
         return server_error_response(e)
@@ -229,7 +263,6 @@ def upload():
         return get_json_result(
             data=False, retmsg='No file part!', retcode=RetCode.ARGUMENT_ERROR)
-
     file = request.files['file']
     if file.filename == '':
         return get_json_result(
@@ -253,7 +286,6 @@ def upload():
             location += "_"
         blob = request.files['file'].read()
         MINIO.put(kb_id, location, blob)
-
         doc = {
             "id": get_uuid(),
             "kb_id": kb.id,
@@ -266,42 +298,11 @@ def upload():
             "size": len(blob),
             "thumbnail": thumbnail(filename, blob)
         }
-
-        form_data = request.form
-        if "parser_id" in form_data.keys():
-            if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
-                doc["parser_id"] = request.form.get("parser_id").strip()
         if doc["type"] == FileType.VISUAL:
             doc["parser_id"] = ParserType.PICTURE.value
         if re.search(r"\.(ppt|pptx|pages)$", filename):
             doc["parser_id"] = ParserType.PRESENTATION.value
-
-        doc_result = DocumentService.insert(doc)
-
+        doc = DocumentService.insert(doc)
+        return get_json_result(data=doc.to_json())
     except Exception as e:
         return server_error_response(e)
-
-    if "run" in form_data.keys():
-        if request.form.get("run").strip() == "1":
-            try:
-                info = {"run": 1, "progress": 0}
-                info["progress_msg"] = ""
-                info["chunk_num"] = 0
-                info["token_num"] = 0
-                DocumentService.update_by_id(doc["id"], info)
-                # if str(req["run"]) == TaskStatus.CANCEL.value:
-                tenant_id = DocumentService.get_tenant_id(doc["id"])
-                if not tenant_id:
-                    return get_data_error_result(retmsg="Tenant not found!")
-
-                # e, doc = DocumentService.get_by_id(doc["id"])
-                TaskService.filter_delete([Task.doc_id == doc["id"]])
-                e, doc = DocumentService.get_by_id(doc["id"])
-                doc = doc.to_dict()
-                doc["tenant_id"] = tenant_id
-                bucket, name = File2DocumentService.get_minio_address(doc_id=doc["id"])
-                queue_tasks(doc, bucket, name)
-            except Exception as e:
-                return server_error_response(e)
-
-        return get_json_result(data=doc_result.to_json())
\ No newline at end of file
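With `"stream": false` the same route degrades to a single JSON response; a hypothetical call, with the same placeholder URL and key as above:

```python
# Sketch of the non-stream path: one blocking request, one JSON body.
import requests

r = requests.post(
    "http://localhost:9380/v1/api/completion",          # assumed mount point
    headers={"Authorization": "Bearer <API_KEY>"},
    json={"conversation_id": "<conversation-id>",
          "messages": [{"role": "user", "content": "hi"}],
          "stream": False},
)
body = r.json()                       # body["data"] -> {"answer": ..., "reference": ...}
print(body["data"]["answer"])
```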
diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 2bb813cb6f..ed52500441 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -13,12 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from flask import request
+from flask import request, Response, jsonify
 from flask_login import login_required
 from api.db.services.dialog_service import DialogService, ConversationService, chat
 from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
 from api.utils import get_uuid
 from api.utils.api_utils import get_json_result
+import json


 @manager.route('/set', methods=['POST'])
@@ -103,9 +104,12 @@ def list_convsersation():

 @manager.route('/completion', methods=['POST'])
 @login_required
-@validate_request("conversation_id", "messages")
+#@validate_request("conversation_id", "messages")
 def completion():
     req = request.json
+    #req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
+    #    {"role": "user", "content": "Is there one in Shanghai?"}
+    #]}
     msg = []
     for m in req["messages"]:
         if m["role"] == "system":
@@ -123,13 +127,45 @@ def completion():
             return get_data_error_result(retmsg="Dialog not found!")
         del req["conversation_id"]
         del req["messages"]
-        ans = chat(dia, msg, **req)
+
         if not conv.reference: conv.reference = []
-        conv.reference.append(ans["reference"])
-        conv.message.append({"role": "assistant", "content": ans["answer"]})
-        ConversationService.update_by_id(conv.id, conv.to_dict())
-        return get_json_result(data=ans)
+        conv.message.append({"role": "assistant", "content": ""})
+        conv.reference.append({"chunks": [], "doc_aggs": []})
+
+        def fillin_conv(ans):
+            nonlocal conv
+            if not conv.reference:
+                conv.reference.append(ans["reference"])
+            else:
+                conv.reference[-1] = ans["reference"]
+            conv.message[-1] = {"role": "assistant", "content": ans["answer"]}
+
+        def stream():
+            nonlocal dia, msg, req, conv
+            try:
+                for ans in chat(dia, msg, True, **req):
+                    fillin_conv(ans)
+                    yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans},
+                                               ensure_ascii=False) + "\n\n"
+                ConversationService.update_by_id(conv.id, conv.to_dict())
+            except Exception as e:
+                yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
+                                            "data": {"answer": "**ERROR**: " + str(e), "reference": []}},
+                                           ensure_ascii=False) + "\n\n"
+            yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
+
+        if req.get("stream", True):
+            resp = Response(stream(), mimetype="text/event-stream")
+            resp.headers.add_header("Cache-control", "no-cache")
+            resp.headers.add_header("Connection", "keep-alive")
+            resp.headers.add_header("X-Accel-Buffering", "no")
+            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+            return resp
+
+        else:
+            answer = None
+            for ans in chat(dia, msg, False, **req):
+                answer = ans
+            fillin_conv(answer)
+            ConversationService.update_by_id(conv.id, conv.to_dict())
+            return get_json_result(data=answer)
     except Exception as e:
         return server_error_response(e)
diff --git a/api/apps/system_app.py b/api/apps/system_app.py
new file mode 100644
index 0000000000..933d1a7444
--- /dev/null
+++ b/api/apps/system_app.py
@@ -0,0 +1,67 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License
+#
+from flask_login import login_required
+
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.utils.api_utils import get_json_result
+from api.versions import get_rag_version
+from rag.settings import SVR_QUEUE_NAME
+from rag.utils.es_conn import ELASTICSEARCH
+from rag.utils.minio_conn import MINIO
+from timeit import default_timer as timer
+
+from rag.utils.redis_conn import REDIS_CONN
+
+
+@manager.route('/version', methods=['GET'])
+@login_required
+def version():
+    return get_json_result(data=get_rag_version())
+
+
+@manager.route('/status', methods=['GET'])
+@login_required
+def status():
+    res = {}
+    st = timer()
+    try:
+        res["es"] = ELASTICSEARCH.health()
+        res["es"]["elapsed"] = "{:.1f}".format((timer() - st) * 1000.)
+    except Exception as e:
+        res["es"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        MINIO.health()
+        res["minio"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.)}
+    except Exception as e:
+        res["minio"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        KnowledgebaseService.get_by_id("x")
+        res["mysql"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.)}
+    except Exception as e:
+        res["mysql"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}
+
+    st = timer()
+    try:
+        qinfo = REDIS_CONN.health(SVR_QUEUE_NAME)
+        res["redis"] = {"status": "green", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "pending": qinfo["pending"]}
+    except Exception as e:
+        res["redis"] = {"status": "red", "elapsed": "{:.1f}".format((timer() - st) * 1000.), "error": str(e)}
+
+    return get_json_result(data=res)
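A hypothetical way to probe the new route; it sits behind `login_required`, so the session below is assumed to carry a logged-in cookie, and the URL prefix is assumed from the app's default mount point:

```python
# Hypothetical sketch: polling /system/status and printing a one-line summary.
# Field names mirror what status() assembles above.
import requests

s = requests.Session()                 # assumed: s already holds a logged-in session cookie
res = s.get("http://localhost:9380/v1/system/status").json()["data"]
for svc in ("es", "minio", "mysql", "redis"):
    info = res.get(svc, {})
    print("{:6s} {:7s} {} ms".format(svc, info.get("status", "?"), info.get("elapsed", "?")))
```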
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index ed8b050395..8d928afab4 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 #
 import re
+from copy import deepcopy

 from api.db import LLMType
 from api.db.db_models import Dialog, Conversation
@@ -71,7 +72,7 @@ def count():
     return max_length, msg


-def chat(dialog, messages, **kwargs):
+def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     llm = LLMService.query(llm_name=dialog.llm_id)
     if not llm:
@@ -82,7 +83,10 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         max_tokens = llm[0].max_tokens
     kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
     embd_nms = list(set([kb.embd_id for kb in kbs]))
-    assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
+    if len(embd_nms) != 1:
+        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
+        return

     questions = [m["content"] for m in messages if m["role"] == "user"]
     embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
@@ -94,7 +98,9 @@ def chat(dialog, messages, stream=True, **kwargs):
     if field_map:
         chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
         ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
-        if ans: return ans
+        if ans:
+            yield ans
+            return

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":
@@ -118,8 +124,9 @@ def chat(dialog, messages, stream=True, **kwargs):
         "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     if not knowledges and prompt_config.get("empty_response"):
-        return {
-            "answer": prompt_config["empty_response"], "reference": kbinfos}
+        yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+        return

     kwargs["knowledge"] = "\n".join(knowledges)
     gen_conf = dialog.llm_setting
@@ -130,33 +137,45 @@ def chat(dialog, messages, stream=True, **kwargs):
         gen_conf["max_tokens"] = min(
             gen_conf["max_tokens"],
             max_tokens - used_token_count)
-    answer = chat_mdl.chat(
-        prompt_config["system"].format(
-            **kwargs), msg, gen_conf)
-    chat_logger.info("User: {}|Assistant: {}".format(
-        msg[-1]["content"], answer))
-
-    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
-        answer, idx = retrievaler.insert_citations(answer,
-                                                   [ck["content_ltks"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   [ck["vector"]
-                                                    for ck in kbinfos["chunks"]],
-                                                   embd_mdl,
-                                                   tkweight=1 - dialog.vector_similarity_weight,
-                                                   vtweight=dialog.vector_similarity_weight)
-        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
-        recall_docs = [
-            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
-        if not recall_docs: recall_docs = kbinfos["doc_aggs"]
-        kbinfos["doc_aggs"] = recall_docs
-
-    for c in kbinfos["chunks"]:
-        if c.get("vector"):
-            del c["vector"]
-    if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
-        answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
-    return {"answer": answer, "reference": kbinfos}
+
+    def decorate_answer(answer):
+        nonlocal prompt_config, knowledges, kwargs, kbinfos
+        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
+            answer, idx = retrievaler.insert_citations(answer,
+                                                       [ck["content_ltks"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       [ck["vector"]
+                                                        for ck in kbinfos["chunks"]],
+                                                       embd_mdl,
+                                                       tkweight=1 - dialog.vector_similarity_weight,
+                                                       vtweight=dialog.vector_similarity_weight)
+            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
+            recall_docs = [
+                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
+            if not recall_docs: recall_docs = kbinfos["doc_aggs"]
+            kbinfos["doc_aggs"] = recall_docs
+
+        refs = deepcopy(kbinfos)
+        for c in refs["chunks"]:
+            if c.get("vector"):
+                del c["vector"]
+        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
+            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
+        return {"answer": answer, "reference": refs}
+
+    if stream:
+        answer = ""
+        for ans in chat_mdl.chat_streamly(prompt_config["system"].format(**kwargs), msg, gen_conf):
+            answer = ans
+            yield {"answer": answer, "reference": {}}
+        yield decorate_answer(answer)
+    else:
+        answer = chat_mdl.chat(
+            prompt_config["system"].format(**kwargs), msg, gen_conf)
+        chat_logger.info("User: {}|Assistant: {}".format(
+            msg[-1]["content"], answer))
+        yield decorate_answer(answer)


 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
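Because `chat()` contains `yield`, it is a generator in both modes, so every caller drives it the same way; a minimal sketch of that contract, simplified from the endpoint code above:

```python
# Sketch of the consumption contract for the generator-style chat() above.
def answer_once(dialog, messages, **kwargs):
    final = None
    for ans in chat(dialog, messages, False, **kwargs):   # non-stream: drain, keep last
        final = ans
    return final                                          # {"answer": ..., "reference": ...}

def answer_streaming(dialog, messages, **kwargs):
    for ans in chat(dialog, messages, True, **kwargs):    # stream: growing answers;
        yield ans["answer"]                               # the full references arrive last
```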
"reference": {}} + yield decorate_answer(answer) + else: + answer = chat_mdl.chat( + prompt_config["system"].format( + **kwargs), msg, gen_conf) + chat_logger.info("User: {}|Assistant: {}".format( + msg[-1]["content"], answer)) + return decorate_answer(answer) def use_sql(question, field_map, tenant_id, chat_mdl, quota=True): diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 70e7af77c5..d85c15070f 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -43,7 +43,7 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, docs = cls.model.select().where( (cls.model.kb_id == kb_id), (fn.LOWER(cls.model.name).contains(keywords.lower())) - ) + ) else: docs = cls.model.select().where(cls.model.kb_id == kb_id) count = docs.count() @@ -75,7 +75,7 @@ def insert(cls, doc): def delete(cls, doc): e, kb = KnowledgebaseService.get_by_id(doc.kb_id) if not KnowledgebaseService.update_by_id( - kb.id, {"doc_num": kb.doc_num - 1}): + kb.id, {"doc_num": max(0, kb.doc_num - 1)}): raise RuntimeError("Database error (Knowledgebase)!") return cls.delete_by_id(doc.id) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index a287cd6745..5129bb798f 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -172,8 +172,18 @@ def describe(self, image, max_tokens=300): def chat(self, system, history, gen_conf): txt, used_tokens = self.mdl.chat(system, history, gen_conf) - if TenantLLMService.increase_usage( + if not TenantLLMService.increase_usage( self.tenant_id, self.llm_type, used_tokens, self.llm_name): database_logger.error( "Can't update token usage for {}/CHAT".format(self.tenant_id)) return txt + + def chat_streamly(self, system, history, gen_conf): + for txt in self.mdl.chat_streamly(system, history, gen_conf): + if isinstance(txt, int): + if not TenantLLMService.increase_usage( + self.tenant_id, self.llm_type, txt, self.llm_name): + database_logger.error( + "Can't update token usage for {}/CHAT".format(self.tenant_id)) + return + yield txt diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index 1ce3664d01..df8d6dfe19 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -25,7 +25,6 @@ from werkzeug.http import HTTP_STATUS_CODES from api.utils import json_dumps -from api.versions import get_rag_version from api.settings import RetCode from api.settings import ( REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC, @@ -84,9 +83,6 @@ def request(**kwargs): return sess.send(prepped, stream=stream, timeout=timeout) -rag_version = get_rag_version() or '' - - def get_exponential_backoff_interval(retries, full_jitter=False): """Calculate the exponential backoff wait time.""" # Will be zero if factor equals 0 diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 797d3fea16..f8e741e764 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -20,7 +20,6 @@ import openai from ollama import Client from rag.nlp import is_english -from rag.utils import num_tokens_from_string class Base(ABC): @@ -44,6 +43,31 @@ def chat(self, system, history, gen_conf): except openai.APIError as e: return "**ERROR**: " + str(e), 0 + def chat_streamly(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + ans = "" + total_tokens = 0 + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + stream=True, + **gen_conf) + for resp in response: + if not 
diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py
index 797d3fea16..f8e741e764 100644
--- a/rag/llm/chat_model.py
+++ b/rag/llm/chat_model.py
@@ -20,7 +20,6 @@
 import openai
 from ollama import Client
 from rag.nlp import is_english
-from rag.utils import num_tokens_from_string


 class Base(ABC):
@@ -44,6 +43,31 @@ def chat(self, system, history, gen_conf):
         except openai.APIError as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf)
+            for resp in response:
+                if resp.choices[0].delta.content:
+                    ans += resp.choices[0].delta.content
+                    total_tokens += 1  # rough estimate: one delta per token
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
+

 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
@@ -97,6 +121,35 @@ def chat(self, system, history, gen_conf):

         return "**ERROR**: " + response.message, tk_count

+    def chat_streamly(self, system, history, gen_conf):
+        from http import HTTPStatus
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        ans = ""
+        tk_count = 0
+        try:
+            response = Generation.call(
+                self.model_name,
+                messages=history,
+                result_format='message',
+                stream=True,
+                **gen_conf
+            )
+            for resp in response:
+                if resp.status_code == HTTPStatus.OK:
+                    ans = resp.output.choices[0]['message']['content']
+                    tk_count = resp.usage.total_tokens
+                    if resp.output.choices[0].get("finish_reason", "") == "length":
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                    yield ans
+                else:
+                    yield ans + "\n**ERROR**: " + (resp.message if str(resp.message).find("Access") < 0
+                                                   else "Out of credit. Please set the API key in **settings > Model providers.**")
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+

 class ZhipuChat(Base):
     def __init__(self, key, model_name="glm-3-turbo", **kwargs):
@@ -122,6 +175,34 @@ def chat(self, system, history, gen_conf):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        if "presence_penalty" in gen_conf: del gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: del gen_conf["frequency_penalty"]
+        ans = ""
+        tk_count = 0
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                **gen_conf
+            )
+            for resp in response:
+                if resp.choices[0].delta.content:
+                    ans += resp.choices[0].delta.content
+                if resp.usage:
+                    tk_count = resp.usage.total_tokens
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield tk_count
+

 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
@@ -148,3 +229,28 @@ def chat(self, system, history, gen_conf):
         except Exception as e:
             return "**ERROR**: " + str(e), 0

+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        options = {}
+        if "temperature" in gen_conf: options["temperature"] = gen_conf["temperature"]
+        if "max_tokens" in gen_conf: options["num_predict"] = gen_conf["max_tokens"]
+        if "top_p" in gen_conf: options["top_p"] = gen_conf["top_p"]
+        if "presence_penalty" in gen_conf: options["presence_penalty"] = gen_conf["presence_penalty"]
+        if "frequency_penalty" in gen_conf: options["frequency_penalty"] = gen_conf["frequency_penalty"]
+        ans = ""
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                stream=True,
+                options=options
+            )
+            for resp in response:
+                if resp["done"]:
+                    yield resp["prompt_eval_count"] + resp["eval_count"]
+                    return
+                ans += resp["message"]["content"]
+                yield ans
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+        yield 0
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 700a30aae6..413805a8f3 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -80,7 +80,7 @@ def set_progress(task_id, from_page=0, to_page=-1,
     if to_page > 0:
         if msg:
-            msg = f"Page({from_page+1}~{to_page+1}): " + msg
+            msg = f"Page({from_page + 1}~{to_page + 1}): " + msg
     d = {"progress_msg": msg}
     if prog is not None:
         d["progress"] = prog
@@ -124,7 +124,7 @@ def get_minio_binary(bucket, name):
 def build(row):
     if row["size"] > DOC_MAXIMUM_SIZE:
         set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
-                     (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
+                                             (int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
         return []

     callback = partial(
@@ -138,12 +138,12 @@ def build(row):
         bucket, name = File2DocumentService.get_minio_address(doc_id=row["doc_id"])
         binary = get_minio_binary(bucket, name)
         cron_logger.info(
-            "From minio({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "From minio({}) {}/{}".format(timer() - st, row["location"], row["name"]))
         cks = chunker.chunk(row["name"], binary=binary, from_page=row["from_page"],
                             to_page=row["to_page"], lang=row["language"], callback=callback,
                             kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
         cron_logger.info(
-            "Chunkking({}) {}/{}".format(timer()-st, row["location"], row["name"]))
+            "Chunkking({}) {}/{}".format(timer() - st, row["location"], row["name"]))
     except TimeoutError as e:
         callback(-1, f"Internal server error: Fetch file timeout. Could you try it again.")
         cron_logger.error(
@@ -173,7 +173,7 @@ def build(row):
         d.update(ck)
         md5 = hashlib.md5()
         md5.update((ck["content_with_weight"] +
-                    str(d["doc_id"])).encode("utf-8"))
+                    str(d["doc_id"])).encode("utf-8"))
         d["_id"] = md5.hexdigest()
         d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
         d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -261,7 +261,7 @@ def main():

         st = timer()
         cks = build(r)
-        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Build chunks({}): {:.2f}".format(r["name"], timer() - st))
         if cks is None:
             continue
         if not cks:
@@ -271,7 +271,7 @@ def main():
         ## set_progress(r["did"], -1, "ERROR: ")
         callback(
             msg="Finished slicing files(%d). Start to embedding the content." %
-            len(cks))
+                len(cks))
         st = timer()
         try:
             tk_count = embedding(cks, embd_mdl, r["parser_config"], callback)
@@ -279,19 +279,19 @@ def main():
             callback(-1, "Embedding error:{}".format(str(e)))
             cron_logger.error(str(e))
             tk_count = 0
-        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Embedding elapsed({}): {:.2f}".format(r["name"], timer() - st))

-        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer()-st))
+        callback(msg="Finished embedding({:.2f})! Start to build index!".format(timer() - st))
         init_kb(r)
         chunk_count = len(set([c["_id"] for c in cks]))
         st = timer()
         es_r = ""
         for b in range(0, len(cks), 32):
-            es_r = ELASTICSEARCH.bulk(cks[b:b+32], search.index_name(r["tenant_id"]))
+            es_r = ELASTICSEARCH.bulk(cks[b:b + 32], search.index_name(r["tenant_id"]))
             if b % 128 == 0:
                 callback(prog=0.8 + 0.1 * (b + 1) / len(cks), msg="")

-        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer()-st))
+        cron_logger.info("Indexing elapsed({}): {:.2f}".format(r["name"], timer() - st))
         if es_r:
             callback(-1, "Index failure!")
             ELASTICSEARCH.deleteByQuery(
@@ -307,8 +307,7 @@ def main():
                 r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
             cron_logger.info(
                 "Chunk doc({}), token({}), chunks({}), elapsed:{:.2f}".format(
-                    r["id"], tk_count, len(cks), timer()-st))
-
+                    r["id"], tk_count, len(cks), timer() - st))

 if __name__ == "__main__":
diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py
index 87348c6aab..78f2346056 100644
--- a/rag/utils/es_conn.py
+++ b/rag/utils/es_conn.py
@@ -43,6 +43,9 @@ def version(self):
         v = v["number"].split(".")[0]
         return int(v) >= 7

+    def health(self):
+        return dict(self.es.cluster.health())
+
     def upsert(self, df, idxnm=""):
         res = []
         for d in df:
diff --git a/rag/utils/minio_conn.py b/rag/utils/minio_conn.py
index fa87ed3b06..f6aff1bae4 100644
--- a/rag/utils/minio_conn.py
+++ b/rag/utils/minio_conn.py
@@ -34,6 +34,16 @@ def __close__(self):
         del self.conn
         self.conn = None

+    def health(self):
+        bucket, fnm, binary = "_t@@@1", "_t@@@1", b"_t@@@1"
+        if not self.conn.bucket_exists(bucket):
+            self.conn.make_bucket(bucket)
+        r = self.conn.put_object(bucket, fnm,
+                                 BytesIO(binary),
+                                 len(binary)
+                                 )
+        return r
+
     def put(self, bucket, fnm, binary):
         for _ in range(3):
             try:
diff --git a/rag/utils/redis_conn.py b/rag/utils/redis_conn.py
index 2f06ae92aa..1fa0604b8b 100644
--- a/rag/utils/redis_conn.py
+++ b/rag/utils/redis_conn.py
@@ -44,6 +44,10 @@ def __open__(self):
             logging.warning("Redis can't be connected.")
         return self.REDIS

+    def health(self, queue_name):
+        self.REDIS.ping()
+        return self.REDIS.xinfo_groups(queue_name)[0]
+
     def is_alive(self):
         return self.REDIS is not None
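These helpers can also be exercised on their own; a minimal sketch, assuming the connections are configured as usual via `rag.settings`:

```python
# Minimal sketch: calling the new health helpers directly from a Python shell.
from rag.settings import SVR_QUEUE_NAME
from rag.utils.es_conn import ELASTICSEARCH
from rag.utils.minio_conn import MINIO
from rag.utils.redis_conn import REDIS_CONN

print("es:", ELASTICSEARCH.health().get("status"))       # green / yellow / red
MINIO.health()                                            # raises if MinIO is unreachable
print("redis pending:", REDIS_CONN.health(SVR_QUEUE_NAME)["pending"])
```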