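# wenda.py: HTTP entry point for the wenda LLM service. It serves the web UI from
# views/, exposes the chat, knowledge-base (zhishiku) and save_news APIs via bottle,
# and loads the model and knowledge base on background threads at startup.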
import threading, os, json
import datetime
from bottle import route, response, request, static_file, hook
import bottle
from plugins import settings

LLM = settings.load_LLM()
import torch

if settings.logging:
    from plugins.defineSQL import session_maker, 记录

mutex = threading.Lock()
@route('/static/<path:path>')
def staticjs(path='-'):
    return static_file(path, root="views/static/")


@route('/:name')
def static(name='-'):
    return static_file(name, root="views")


@route('/')
def index():
    response.set_header("Pragma", "no-cache")
    response.add_header("Cache-Control", "must-revalidate")
    response.add_header("Cache-Control", "no-cache")
    response.add_header("Cache-Control", "no-store")
    return static_file("index.html", root="views")
当前用户 = None


@route('/api/chat_now', method='GET')
def api_chat_now():
    return '当前状态:' + 当前用户[0]
@hook('before_request')
def validate():
    REQUEST_METHOD = request.environ.get('REQUEST_METHOD')
    HTTP_ACCESS_CONTROL_REQUEST_METHOD = request.environ.get('HTTP_ACCESS_CONTROL_REQUEST_METHOD')
    if REQUEST_METHOD == 'OPTIONS' and HTTP_ACCESS_CONTROL_REQUEST_METHOD:
        # Let CORS preflight requests fall through to the POST handlers below,
        # which set the Access-Control-* headers themselves.
        request.environ['REQUEST_METHOD'] = HTTP_ACCESS_CONTROL_REQUEST_METHOD
@route('/api/save_news', method='OPTIONS')
@route('/api/save_news', method='POST')
def api_save_news():
    response.set_header('Access-Control-Allow-Origin', '*')
    response.add_header('Access-Control-Allow-Methods', 'POST,OPTIONS')
    response.add_header('Access-Control-Allow-Headers',
                        'Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token')
    try:
        data = request.json
        if not data:
            return '0'
        title = data.get('title')
        txt = data.get('txt')
        cut_file = f"txt/{title}.txt"
        with open(cut_file, 'w', encoding='utf-8') as f:
            f.write(txt)
        return '1'
    except Exception as e:
        print(e)
        return '2'
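# Example request (a sketch, not part of this module; the JSON keys mirror the
# handler above, and the port comes from settings.port in your config):
#   curl -X POST http://127.0.0.1:<port>/api/save_news \
#        -H "Content-Type: application/json" \
#        -d '{"title": "example", "txt": "text to save"}'
# '1' means txt/<title>.txt was written, '0' means no JSON body was received,
# '2' means an exception occurred.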
@route('/api/find', method='POST')
def api_find():
    data = request.json
    prompt = data.get('prompt')
    return json.dumps(zhishiku.find(prompt))
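# Example request (a sketch; 'prompt' is the only key the handler reads, and the
# response is the JSON-serialized result of zhishiku.find):
#   curl -X POST http://127.0.0.1:<port>/api/find \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "query text"}'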
@route('/api/chat_stream', method='POST')
def api_chat_stream():
    data = request.json
    prompt = data.get('prompt')
    keyword = data.get('keyword')
    max_length = data.get('max_length')
    if max_length is None:
        max_length = 2048
    top_p = data.get('top_p')
    if top_p is None:
        top_p = 0.7
    temperature = data.get('temperature')
    if temperature is None:
        temperature = 0.9
    use_zhishiku = data.get('zhishiku')
    if use_zhishiku is None:
        use_zhishiku = False
    history = data.get('history')
    history_formatted = LLM.chat_init(history)
    response = ''
    # print(request.environ)
    IP = request.environ.get('HTTP_X_REAL_IP') or request.environ.get('REMOTE_ADDR')
    global 当前用户
    error = ""
    with mutex:
        footer = '///'
        yield str(len(prompt)) + '字正在计算'  # "<n> characters, computing"
        if use_zhishiku:
            if keyword is None:
                keyword = prompt
            # print(keyword)
            # Retrieve matching passages from the knowledge base, prepend them to
            # the prompt, and record their titles for the source footer.
            response_d = zhishiku.find(keyword)
            torch.cuda.empty_cache()
            output_sources = [i['title'] for i in response_d]
            results = '\n---\n'.join([i['content'] for i in response_d])
            prompt = 'system:学习以下文段, 用中文回答用户问题。如果无法从中得到答案,忽略文段内容并用中文回答用户问题。\n\n' + results + '\nuser:' + prompt
            footer = "\n来源:\n" + ('\n').join(output_sources) + '///'
            yield footer
        print("\033[1;32m" + IP + ":\033[1;31m" + prompt + "\033[1;37m")
        try:
            # Stream cumulative generations from the model, each chunk terminated by the footer.
            for response in LLM.chat_one(prompt, history_formatted, max_length, top_p, temperature, zhishiku=use_zhishiku):
                if response:
                    yield response + footer
        except Exception as e:
            error = str(e)
            print("错误", settings.red, error, settings.white, e)  # "error"
            response = ''
        torch.cuda.empty_cache()
    if response == '':
        # Generation failed: report the error ("an error occurred, reloading the model")
        # and terminate the process so it can be relaunched with a fresh model.
        yield "发生错误,正在重新加载模型" + error + '///'
        os._exit(0)
    if settings.logging:
        with session_maker() as session:
            jl = 记录(时间=datetime.datetime.now(), IP=IP, 问=prompt, 答=response)
            session.add(jl)
            session.commit()
    print(response)
    yield "/././"
model = None
tokenizer = None


def load_model():
    global 当前用户
    mutex.acquire()
    当前用户 = ['模型加载中', '', '']  # "model loading"
    LLM.load_model()
    mutex.release()
    torch.cuda.empty_cache()
    print(settings.green, "模型加载完成", settings.white)  # "model loaded"


thread_load_model = threading.Thread(target=load_model)
thread_load_model.start()
zhishiku = None


def load_zsk():
    global zhishiku
    zhishiku = settings.load_zsk()
    print(settings.green, "知识库加载完成", settings.white)  # "knowledge base loaded"


thread_load_zsk = threading.Thread(target=load_zsk)
thread_load_zsk.start()
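# The model and knowledge base load on background threads so the bottle server
# below can start accepting requests immediately; /api/chat_stream simply blocks
# on `mutex` until load_model() releases it.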
bottle.debug(True)
bottle.run(server='paste', host="0.0.0.0", port=settings.port, quiet=True)