RAGFlow streaming output suggestions #3738 #3881

Open · wants to merge 102 commits into base: main
Changes from 7 commits

Commits (102)
b071797
Add test for document (#3497)
Feiue Nov 19, 2024
b1001bf
fix: laws.py added missing import logging (#3501)
michalmasrna1 Nov 19, 2024
f424f19
Fix bugs (#3502)
JinHai-CN Nov 20, 2024
36e75b3
fix synonym bug (#3506)
KevinHuSh Nov 20, 2024
e16b7c5
smooth term weight (#3510)
KevinHuSh Nov 20, 2024
81f92d0
feat: Add Datasets component to home page #3221 (#3508)
cike8899 Nov 20, 2024
9314b03
fix: keyerror issue (#3512)
KevinHuSh Nov 20, 2024
95dc59d
Added kb_id filter to knn. Fix #3458 (#3513)
yuzhichang Nov 20, 2024
2062d7f
Make spark model robuster to model name (#3514)
KevinHuSh Nov 20, 2024
273678d
Fix: potential risk (#3515)
KevinHuSh Nov 20, 2024
c55231b
Fix set_output type hint (#3516)
yuzhichang Nov 20, 2024
cbef6fd
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
f5ef1fb
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
f4a7b92
Merge remote-tracking branch 'remote/main'
Nov 21, 2024
6976db1
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
b8c31d5
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
e9140ae
Merge remote-tracking branch 'remote/main'
Nov 22, 2024
70359e0
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
22b0ad9
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
ccb4e2f
Merge remote-tracking branch 'remote/main'
Nov 25, 2024
3f3e073
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
accbe5f
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
facc2d6
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
d030a23
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
f0c7e25
Merge remote-tracking branch 'remote/main'
Nov 26, 2024
32f6517
Merge remote-tracking branch 'remote/main'
Nov 27, 2024
1211e22
Merge remote-tracking branch 'remote/main'
Nov 27, 2024
7934014
Merge remote-tracking branch 'remote/main'
Nov 28, 2024
49ad2bd
Merge remote-tracking branch 'remote/main'
Nov 28, 2024
00c1b41
Merge remote-tracking branch 'remote/main'
Nov 29, 2024
dd9fec8
Merge remote-tracking branch 'remote/main'
Nov 29, 2024
202ada4
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
851ad89
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
39ab46e
Merge remote-tracking branch 'remote/main'
Dec 3, 2024
c61bd86
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
5dfc60d
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
2160487
test: add session.py logs
Dec 4, 2024
c0dfab5
test: add dialog_service.py logs
Dec 4, 2024
742a871
test: chat_streamly stream
Dec 4, 2024
121d78a
Merge remote-tracking branch 'remote/main' into main_lz
Dec 4, 2024
762dea0
Merge remote-tracking branch 'remote/main'
Dec 4, 2024
d932105
Merge remote-tracking branch 'remote/main' into main_lz
Dec 4, 2024
c964b36
Merge remote-tracking branch 'remote/main' into main_lz
Dec 5, 2024
e51fcf8
Merge remote-tracking branch 'refs/remotes/remote/main' into main_lz
Dec 5, 2024
d090b1c
test: chat_streamly delta
Dec 5, 2024
3570ce4
Merge remote-tracking branch 'origin/main_lz'
Dec 5, 2024
30e9c29
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
16cc7ec
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
7c9c42f
Test: Comment log printing
Dec 5, 2024
57375be
Test: delete log printing
Dec 5, 2024
89466ed
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 5, 2024
a6d21a1
Merge remote-tracking branch 'remote/main'
Dec 5, 2024
c94dc7f
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 6, 2024
7e4c5fe
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
c26bdf7
Fix: The issue of truncation of the streaming output of the char mode…
Dec 6, 2024
06c8745
Fix: dialog_service.py The issue of truncation of the streaming outpu…
Dec 6, 2024
96dd427
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
144f4e8
Merge remote-tracking branch 'refs/remotes/origin/main' into main_rem…
Dec 6, 2024
47b9c0c
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 6, 2024
df7defc
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
fa86cdf
Merge remote-tracking branch 'remote_lz/main'
Dec 6, 2024
4be5f57
Merge remote-tracking branch 'remote/main'
Dec 6, 2024
67f5ad1
Merge remote-tracking branch 'refs/remotes/remote/main'
Dec 9, 2024
5cf6dc1
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 9, 2024
2c58f3c
Merge remote-tracking branch 'remote_lz/main' into main_remote_lz
Dec 9, 2024
5b0d908
Merge remote-tracking branch 'refs/remotes/origin/main' into main_rem…
Dec 9, 2024
6903acc
Fix: Delete the content of the comment
Dec 9, 2024
3a64029
Fix: Resolve conflicts
Dec 9, 2024
12175ab
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
46f6d28
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
ea93fb7
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
a3667dc
Merge remote-tracking branch 'remote/main'
Dec 9, 2024
602f392
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
ba26f10
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
2937810
Merge remote-tracking branch 'remote_lz/main' into main_remote_lz
Dec 10, 2024
ba1161e
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
b7455a8
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
273fded
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
23aca28
Fix: Remove redundant references
Dec 10, 2024
2a6c252
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
d573164
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 10, 2024
eaf622b
Merge remote-tracking branch 'remote/main'
Dec 10, 2024
b55f21e
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 10, 2024
d9d196d
Merge remote-tracking branch 'origin/main' into main_remote_lz
Dec 10, 2024
5d62a80
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 11, 2024
70e9e73
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 11, 2024
cfb877d
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 12, 2024
b5b7397
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 12, 2024
8e8ad6c
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
34a1b3e
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
b020cf8
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
69344cb
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 13, 2024
56dd4f4
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 16, 2024
3fa53d1
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 17, 2024
a0c7211
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 17, 2024
022c45e
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 27, 2024
fdf6a83
Feat: update the chat function to streaming presentation
Dec 27, 2024
c209d09
Fix: Merge code
Dec 27, 2024
d898dff
Fix: exegesis total_tokens
Dec 27, 2024
c3c0f0d
Merge remote-tracking branch 'remote/main' into main_remote_lz
Dec 27, 2024
d1f0ed3
Merge remote-tracking branch 'remote/main' into main_remote_lz
Jan 14, 2025
237a206
Merge remote-tracking branch 'remote/main' into main_remote_lz
Jan 16, 2025
3 changes: 2 additions & 1 deletion api/apps/sdk/session.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import json
from copy import deepcopy
@@ -189,7 +190,7 @@ def stream():
        nonlocal dia, msg, req, conv
        try:
            for ans in chat(dia, msg, **req):
                #logging.info("ans : {}".format(ans))
                fillin_conv(ans)
                yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
            ConversationService.update_by_id(conv.id, conv.to_dict())
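For reference, a consumer of this endpoint reads the SSE-style "data:" lines emitted by stream() above. A minimal client sketch, assuming the requests library; the URL, payload, and headers are placeholders and not part of this diff:

    # Hedged sketch: reading the "data:" events produced by stream() above.
    import json
    import requests

    def consume_completion(url, payload, headers):
        # stream=True keeps the connection open so events arrive incrementally
        with requests.post(url, json=payload, headers=headers, stream=True) as resp:
            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data:"):
                    continue
                event = json.loads(line[len("data:"):])
                if event.get("code") == 0 and isinstance(event.get("data"), dict):
                    # with this PR, "answer" may carry only the newly generated delta
                    print(event["data"].get("answer", ""), end="", flush=True)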
86 changes: 73 additions & 13 deletions api/db/services/dialog_service.py
@@ -29,6 +29,7 @@
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api import settings
from rag.app.resume import forbidden_select_fields4resume
from rag.llm import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
@@ -273,21 +274,80 @@ def decorate_answer(answer):
(done_tm - retrieval_tm) * 1000)
return {"answer": answer, "reference": refs, "prompt": prompt}

# if stream:
# last_ans = ""
# answer = ""
# for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
# answer = ans
# logging.info("answer_stream : {}".format(ans))
# delta_ans = ans[len(last_ans):]
# if num_tokens_from_string(delta_ans) < 16:
# continue
# last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# delta_ans = answer[len(last_ans):]
# if delta_ans:
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# yield decorate_answer(answer)

    if stream:
        logging.info("stream_mode : {}".format(msg[1:]))
        answer = ""
        for delta in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
            # Check whether the delta is the total token count or a notification message
            if isinstance(delta, int):
                # chat_streamly yields the final total token count as a plain int
                total_tokens = delta
                continue
            if isinstance(delta, str):
                if delta.isdigit():
                    # Handle a total token count delivered as a digit string (if needed)
                    total_tokens = int(delta)
                    # logging.info(f"Total tokens used: {total_tokens}")
                    continue
                elif delta in [LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN]:
                    # Handle the length-limit notification
                    answer += delta
                    # logging.info(f"Length notification: {delta}")
                    audio = tts(tts_mdl, delta)
                    yield {"answer": answer, "reference": {}, "audio_binary": audio}
                    continue
                elif "\n**ERROR**:" in delta:
                    # Handle error messages
                    answer += delta
                    # logging.error(f"Error in response: {delta}")
                    yield {"answer": answer, "reference": {}, "audio_binary": b''}  # no audio on errors
                    continue

                # Process the incremental text
                delta_ans = delta
                # if num_tokens_from_string(delta_ans) < 16:
                #     continue  # adjust the threshold as needed

                # Update the accumulated answer
                answer += delta_ans

                # Generate audio for the new fragment
                audio = tts(tts_mdl, delta_ans)
                # logging.info(f"Generated audio for delta: {delta_ans}")
                yield {"answer": delta_ans, "reference": {}, "audio_binary": audio}
            elif isinstance(delta, dict):
                # In case chat_streamly still returns a dict (not recommended),
                # e.g. {"new_text": "<newly added text>", "position": 10}
                new_text = delta.get("new_text", "")
                if not new_text:
                    continue
                if num_tokens_from_string(new_text) < 16:
                    continue

                # Update the accumulated answer
                answer += new_text

                # Generate audio for the new fragment
                audio = tts(tts_mdl, new_text)
                logging.info(f"Generated audio for new_text: {new_text}")
                yield {"answer": answer, "reference": {}, "audio_binary": audio}

        # Finally, decorate the answer
        decorated_answer = decorate_answer(answer)
        logging.info(f"Final decorated answer: {decorated_answer}")
        yield decorated_answer
    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        logging.debug("User: {}|Assistant: {}".format(
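To make the yield protocol of the stream branch explicit: intermediate dicts carry text in "answer" with an empty "reference", and the last dict comes from decorate_answer(), which also includes a "prompt" key. A hedged consumer sketch (collect_stream is a hypothetical helper, not part of this diff):

    # Sketch of a caller unpacking the streamed payloads from chat().
    def collect_stream(chunks):
        final = None
        for chunk in chunks:
            if "prompt" in chunk:                 # decorate_answer() output arrives last
                final = chunk                     # carries the full answer plus references
            else:
                # intermediate payload; "answer" holds the latest delta (or, for
                # notifications/errors, the accumulated text so far)
                print(chunk.get("answer", ""), end="", flush=True)
        return final

    # usage (hypothetical): final = collect_stream(chat(dia, msg, **req))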
4 changes: 3 additions & 1 deletion poetry.toml
@@ -1,4 +1,6 @@
[virtualenvs]
in-project = true
create = true
prefer-active-python = true
[repositories.tuna]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
78 changes: 66 additions & 12 deletions rag/llm/chat_model.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re

from openai.lib.azure import AzureOpenAI
@@ -57,42 +58,95 @@ def chat(self, system, history, gen_conf):
except openai.APIError as e:
return "**ERROR**: " + str(e), 0

# def chat_streamly(self, system, history, gen_conf):
# if system:
# history.insert(0, {"role": "system", "content": system})
# ans = ""
# total_tokens = 0
# try:
# response = self.client.chat.completions.create(
# model=self.model_name,
# messages=history,
# stream=True,
# **gen_conf)
# for resp in response:
# if not resp.choices: continue
# if not resp.choices[0].delta.content:
# resp.choices[0].delta.content = ""
# ans += resp.choices[0].delta.content
#
# if not hasattr(resp, "usage") or not resp.usage:
# total_tokens = (
# total_tokens
# + num_tokens_from_string(resp.choices[0].delta.content)
# )
# elif isinstance(resp.usage, dict):
# total_tokens = resp.usage.get("total_tokens", total_tokens)
# else: total_tokens = resp.usage.total_tokens
#
# if resp.choices[0].finish_reason == "length":
# if is_chinese(ans):
# ans += LENGTH_NOTIFICATION_CN
# else:
# ans += LENGTH_NOTIFICATION_EN
# yield ans
#
# except openai.APIError as e:
# yield ans + "\n**ERROR**: " + str(e)
#
# yield total_tokens

    def chat_streamly(self, system, history, gen_conf):
        logging.info("lizheng_test: chat_streamly")
        if system:
            history.insert(0, {"role": "system", "content": system})

        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices:
                    continue
                # Get the delta content and make sure it is a string
                delta_content = resp.choices[0].delta.content or ""
                if not delta_content:
                    continue
                # Accumulate the answer
                ans += delta_content

                # Update the token count
                if not hasattr(resp, "usage") or not resp.usage:
                    total_tokens = (
                        total_tokens
                        + num_tokens_from_string(delta_content)
                    )
                elif isinstance(resp.usage, dict):
                    total_tokens = resp.usage.get("total_tokens", total_tokens)
                else:
                    total_tokens = resp.usage.total_tokens

                # Yield only the newly added part
                yield delta_content

                # Handle the finish reason
                if resp.choices[0].finish_reason == "length":
                    if is_chinese(ans):
                        notification = LENGTH_NOTIFICATION_CN
                    else:
                        notification = LENGTH_NOTIFICATION_EN
                    yield notification

        except openai.APIError as e:
            # Yield the error message
            yield ans + "\n**ERROR**: " + str(e)

        # Yield the total token count
        yield total_tokens


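Summarizing the revised contract of chat_streamly() above: it yields incremental str chunks (possibly followed by a length notification or an error suffix), and its very last yield is the total token count as an int. A minimal caller sketch under that assumption (run_stream is a hypothetical helper, not part of this diff):

    # Sketch: consuming the per-delta protocol of the revised chat_streamly().
    def run_stream(chat_mdl, system, history, gen_conf):
        answer = ""
        total_tokens = 0
        for chunk in chat_mdl.chat_streamly(system, history, gen_conf):
            if isinstance(chunk, int):
                total_tokens = chunk              # final yield: token usage
            else:
                answer += chunk                   # delta text, notification, or error suffix
        return answer, total_tokens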
66 changes: 61 additions & 5 deletions rag/llm/cv_model.py
@@ -57,12 +57,48 @@ def chat(self, system, history, gen_conf, image=""):
except Exception as e:
return "**ERROR**: " + str(e), 0

# def chat_streamly(self, system, history, gen_conf, image=""):
# if system:
# history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
#
# ans = ""
# tk_count = 0
# try:
# for his in history:
# if his["role"] == "user":
# his["content"] = self.chat_prompt(his["content"], image)
#
# response = self.client.chat.completions.create(
# model=self.model_name,
# messages=history,
# max_tokens=gen_conf.get("max_tokens", 1000),
# temperature=gen_conf.get("temperature", 0.3),
# top_p=gen_conf.get("top_p", 0.7),
# stream=True
# )
# for resp in response:
# if not resp.choices[0].delta.content: continue
# delta = resp.choices[0].delta.content
# ans += delta
# if resp.choices[0].finish_reason == "length":
# ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
# [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
# tk_count = resp.usage.total_tokens
# if resp.choices[0].finish_reason == "stop": tk_count = resp.usage.total_tokens
# yield ans
# except Exception as e:
# yield ans + "\n**ERROR**: " + str(e)
#
# yield tk_count

def chat_streamly(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]

ans = ""
tk_count = 0
last_sent_length = 0 # 跟踪上一次发送的内容长度

try:
for his in history:
if his["role"] == "user":
@@ -77,19 +113,39 @@ def chat_streamly(self, system, history, gen_conf, image=""):
                stream=True
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                delta = resp.choices[0].delta.content
                ans += delta

                # Compute the newly added part
                new_text = delta
                position = last_sent_length  # start position of the new part
                last_sent_length += len(new_text)

                # Build the incremental-update payload
                incremental_update = {
                    "new_text": new_text,
                    "position": position
                }

                if resp.choices[0].finish_reason == "length":
                    message = "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    incremental_update["new_text"] = message
                    incremental_update["position"] = last_sent_length
                    tk_count = resp.usage.total_tokens

                if resp.choices[0].finish_reason == "stop":
                    tk_count = resp.usage.total_tokens

                yield incremental_update

        except Exception as e:
            yield {"error": f"**ERROR**: {str(e)}"}

        yield tk_count


def image2base64(self, image):
if isinstance(image, bytes):
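For the CV model, the revised chat_streamly() above yields dicts of the form {"new_text": ..., "position": ...} (or {"error": ...} on failure), followed by the token count as an int. A hedged sketch of rebuilding the answer from that protocol (rebuild_answer is a hypothetical helper; it assumes the "position" values are cumulative character offsets, so in-order appends suffice):

    # Sketch: reassembling the incremental-update payloads from cv_model.chat_streamly().
    def rebuild_answer(stream):
        parts = []
        tk_count = 0
        for item in stream:
            if isinstance(item, dict):
                if "error" in item:               # error payload from the except branch
                    parts.append(item["error"])
                    break
                parts.append(item["new_text"])
            elif isinstance(item, int):           # trailing total token count
                tk_count = item
        return "".join(parts), tk_count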