Skip to content

Commit efa73c8

Browse files
authored
refactor: Workflow execution logic (#1913)
1 parent bb58ac6 commit efa73c8

File tree

9 files changed

+94
-70
lines changed

9 files changed

+94
-70
lines changed

apps/application/flow/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@ def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_no
1919
def to_dict(self):
2020
return {'view_type': self.view_type, 'content': self.content, 'runtime_node_id': self.runtime_node_id,
2121
'chat_record_id': self.chat_record_id, 'child_node': self.child_node}
22+
23+
24+
class NodeChunk:
    """Collects the output chunks produced by one workflow node.

    ``status`` starts at ``STATUS_RUNNING`` and is switched to ``STATUS_END``
    by :meth:`end`. ``chunk_list`` holds every chunk added so far, in the
    order it was added.
    """

    # Named status codes. The numeric values are the ones the class has
    # always used, so any existing comparison against 0 / 200 still holds.
    STATUS_RUNNING = 0
    STATUS_END = 200

    def __init__(self):
        self.status = self.STATUS_RUNNING
        self.chunk_list = []

    def add_chunk(self, chunk):
        """Append ``chunk`` to the end of ``chunk_list``."""
        self.chunk_list.append(chunk)

    def end(self, chunk=None):
        """Mark the node as finished.

        :param chunk: optional final chunk; appended before the status flips
                      when it is not ``None``.
        """
        if chunk is not None:
            self.add_chunk(chunk)
        self.status = self.STATUS_END

    def is_end(self):
        """Return ``True`` when ``status`` equals ``STATUS_END``."""
        return self.status == self.STATUS_END

apps/application/flow/i_step_node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from rest_framework import serializers
1818
from rest_framework.exceptions import ValidationError, ErrorDetail
1919

20-
from application.flow.common import Answer
20+
from application.flow.common import Answer, NodeChunk
2121
from application.models import ChatRecord
2222
from application.models.api_key_model import ApplicationPublicAccessClient
2323
from common.constants.authentication_type import AuthenticationType
@@ -175,6 +175,7 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
175175
if up_node_id_list is None:
176176
up_node_id_list = []
177177
self.up_node_id_list = up_node_id_list
178+
self.node_chunk = NodeChunk()
178179
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
179180
"".join([*sorted(up_node_id_list),
180181
node.id]))),
@@ -214,6 +215,7 @@ def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:
214215

215216
def get_write_error_context(self, e):
216217
self.status = 500
218+
self.answer_text = str(e)
217219
self.err_message = str(e)
218220
self.context['run_time'] = time.time() - self.context['start_time']
219221

apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import List, Dict
1111

1212
from django.db.models import QuerySet
13-
13+
from django.db import connection
1414
from application.flow.i_step_node import NodeResult
1515
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
1616
from common.config.embedding_config import VectorStore
@@ -77,6 +77,8 @@ def execute(self, dataset_id_list, dataset_setting, question,
7777
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
7878
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
7979
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
80+
# 手动关闭数据库连接
81+
connection.close()
8082
if embedding_list is None:
8183
return get_none_result(question)
8284
paragraph_list = self.list_paragraph(embedding_list, vector)

apps/application/flow/workflow_manage.py

Lines changed: 56 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
@date:2024/1/9 17:40
77
@desc:
88
"""
9+
import concurrent
910
import json
1011
import threading
1112
import traceback
1213
from concurrent.futures import ThreadPoolExecutor
1314
from functools import reduce
1415
from typing import List, Dict
1516

17+
from django.db import close_old_connections
1618
from django.db.models import QuerySet
1719
from langchain_core.prompts import PromptTemplate
1820
from rest_framework import status
@@ -223,23 +225,6 @@ def pop(self):
223225
return None
224226

225227

226-
class NodeChunk:
227-
def __init__(self):
228-
self.status = 0
229-
self.chunk_list = []
230-
231-
def add_chunk(self, chunk):
232-
self.chunk_list.append(chunk)
233-
234-
def end(self, chunk=None):
235-
if chunk is not None:
236-
self.add_chunk(chunk)
237-
self.status = 200
238-
239-
def is_end(self):
240-
return self.status == 200
241-
242-
243228
class WorkflowManage:
244229
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
245230
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
@@ -273,8 +258,9 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
273258
self.status = 200
274259
self.base_to_response = base_to_response
275260
self.chat_record = chat_record
276-
self.await_future_map = {}
277261
self.child_node = child_node
262+
self.future_list = []
263+
self.lock = threading.Lock()
278264
if start_node_id is not None:
279265
self.load_node(chat_record, start_node_id, start_node_data)
280266
else:
@@ -319,6 +305,7 @@ def get_node_params(n):
319305
self.node_context.append(node)
320306

321307
def run(self):
    """Run the workflow and return its response.

    Calls Django's ``close_old_connections()`` first, then dispatches on the
    ``stream`` entry of ``self.params``: a truthy value selects
    ``run_stream`` starting at ``self.start_node``, anything else selects
    ``run_block``.
    """
    close_old_connections()
    wants_stream = self.params.get('stream')
    if not wants_stream:
        return self.run_block()
    return self.run_stream(self.start_node, None)
@@ -328,8 +315,9 @@ def run_block(self):
328315
非流式响应
329316
@return: 结果
330317
"""
331-
result = self.run_chain_async(None, None)
332-
result.result()
318+
self.run_chain_async(None, None)
319+
while self.is_run():
320+
pass
333321
details = self.get_runtime_details()
334322
message_tokens = sum([row.get('message_tokens') for row in details.values() if
335323
'message_tokens' in row and row.get('message_tokens') is not None])
@@ -350,12 +338,22 @@ def run_stream(self, current_node, node_result_future):
350338
流式响应
351339
@return:
352340
"""
353-
result = self.run_chain_async(current_node, node_result_future)
354-
return tools.to_stream_response_simple(self.await_result(result))
341+
self.run_chain_async(current_node, node_result_future)
342+
return tools.to_stream_response_simple(self.await_result())
355343

356-
def await_result(self, result):
344+
def is_run(self, timeout=0.1):
    """Report whether any future in ``self.future_list`` is still unfinished.

    Blocks for at most ``timeout`` seconds inside
    ``concurrent.futures.wait``.

    :param timeout: maximum number of seconds to wait for the futures.
    :return: ``True`` if at least one future is not done, or if the wait
             itself raised; ``False`` when every future has completed
             (including when the list is empty).
    """
    # The lock is held for the whole wait on purpose: code that appends to
    # future_list under self.lock cannot add new futures while the current
    # set is being inspected.
    with self.lock:
        try:
            outcome = concurrent.futures.wait(self.future_list, timeout)
            return len(outcome.not_done) > 0
        except Exception:
            # Best effort: keep reporting "still running" as before, but
            # leave a trace instead of hiding the failure.
            traceback.print_exc()
            return True
353+
354+
def await_result(self):
355+
try:
356+
while self.is_run():
359357
while True:
360358
chunk = self.node_chunk_manage.pop()
361359
if chunk is not None:
@@ -383,42 +381,39 @@ def await_result(self, result):
383381
'', True, message_tokens, answer_tokens, {})
384382

385383
def run_chain_async(self, current_node, node_result_future):
386-
return executor.submit(self.run_chain_manage, current_node, node_result_future)
384+
future = executor.submit(self.run_chain_manage, current_node, node_result_future)
385+
self.future_list.append(future)
387386

388387
def run_chain_manage(self, current_node, node_result_future):
389388
if current_node is None:
390389
start_node = self.get_start_node()
391390
current_node = get_node(start_node.type)(start_node, self.params, self)
391+
self.node_chunk_manage.add_node_chunk(current_node.node_chunk)
392+
# 添加节点
393+
self.append_node(current_node)
392394
result = self.run_chain(current_node, node_result_future)
393395
if result is None:
394396
return
395397
node_list = self.get_next_node_list(current_node, result)
396398
if len(node_list) == 1:
397399
self.run_chain_manage(node_list[0], None)
398400
elif len(node_list) > 1:
399-
401+
sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y)
400402
# 获取到可执行的子节点
401403
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
402-
node_list]
403-
self.set_await_map(result_list)
404-
[r.get('future').result() for r in result_list]
405-
406-
def set_await_map(self, node_run_list):
407-
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
408-
for index in range(len(sorted_node_run_list)):
409-
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
410-
sorted_node_run_list[i].get('future')
411-
for i in range(index)]
404+
sorted_node_run_list]
405+
try:
406+
self.lock.acquire()
407+
for r in result_list:
408+
self.future_list.append(r.get('future'))
409+
finally:
410+
self.lock.release()
412411

413412
def run_chain(self, current_node, node_result_future=None):
414413
if node_result_future is None:
415414
node_result_future = self.run_node_future(current_node)
416415
try:
417416
is_stream = self.params.get('stream', True)
418-
# 处理节点响应
419-
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
420-
if await_future_list is not None:
421-
[f.result() for f in await_future_list]
422417
result = self.hand_event_node_result(current_node,
423418
node_result_future) if is_stream else self.hand_node_result(
424419
current_node, node_result_future)
@@ -434,16 +429,14 @@ def hand_node_result(self, current_node, node_result_future):
434429
if result is not None:
435430
# 阻塞获取结果
436431
list(result)
437-
# 添加节点
438-
self.node_context.append(current_node)
439432
return current_result
440433
except Exception as e:
441-
# 添加节点
442-
self.node_context.append(current_node)
443434
traceback.print_exc()
444435
self.status = 500
445436
current_node.get_write_error_context(e)
446437
self.answer += str(e)
438+
finally:
439+
current_node.node_chunk.end()
447440

448441
def append_node(self, current_node):
449442
for index in range(len(self.node_context)):
@@ -454,15 +447,14 @@ def append_node(self, current_node):
454447
self.node_context.append(current_node)
455448

456449
def hand_event_node_result(self, current_node, node_result_future):
457-
node_chunk = NodeChunk()
458450
real_node_id = current_node.runtime_node_id
459451
child_node = {}
452+
view_type = current_node.view_type
460453
try:
461454
current_result = node_result_future.result()
462455
result = current_result.write_context(current_node, self)
463456
if result is not None:
464457
if self.is_result(current_node, current_result):
465-
self.node_chunk_manage.add_node_chunk(node_chunk)
466458
for r in result:
467459
content = r
468460
child_node = {}
@@ -487,26 +479,24 @@ def hand_event_node_result(self, current_node, node_result_future):
487479
'child_node': child_node,
488480
'node_is_end': node_is_end,
489481
'real_node_id': real_node_id})
490-
node_chunk.add_chunk(chunk)
491-
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
492-
self.params['chat_record_id'],
493-
current_node.id,
494-
current_node.up_node_id_list,
495-
'', False, 0, 0, {'node_is_end': True,
496-
'runtime_node_id': current_node.runtime_node_id,
497-
'node_type': current_node.type,
498-
'view_type': view_type,
499-
'child_node': child_node,
500-
'real_node_id': real_node_id})
501-
node_chunk.end(chunk)
482+
current_node.node_chunk.add_chunk(chunk)
483+
chunk = (self.base_to_response
484+
.to_stream_chunk_response(self.params['chat_id'],
485+
self.params['chat_record_id'],
486+
current_node.id,
487+
current_node.up_node_id_list,
488+
'', False, 0, 0, {'node_is_end': True,
489+
'runtime_node_id': current_node.runtime_node_id,
490+
'node_type': current_node.type,
491+
'view_type': view_type,
492+
'child_node': child_node,
493+
'real_node_id': real_node_id}))
494+
current_node.node_chunk.add_chunk(chunk)
502495
else:
503496
list(result)
504-
# 添加节点
505-
self.append_node(current_node)
506497
return current_result
507498
except Exception as e:
508499
# 添加节点
509-
self.append_node(current_node)
510500
traceback.print_exc()
511501
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
512502
self.params['chat_record_id'],
@@ -519,12 +509,12 @@ def hand_event_node_result(self, current_node, node_result_future):
519509
'view_type': current_node.view_type,
520510
'child_node': {},
521511
'real_node_id': real_node_id})
522-
if not self.node_chunk_manage.contains(node_chunk):
523-
self.node_chunk_manage.add_node_chunk(node_chunk)
524-
node_chunk.end(chunk)
512+
current_node.node_chunk.add_chunk(chunk)
525513
current_node.get_write_error_context(e)
526514
self.status = 500
527515
return None
516+
finally:
517+
current_node.node_chunk.end()
528518

529519
def run_node_async(self, node):
530520
future = executor.submit(self.run_node, node)
@@ -636,6 +626,8 @@ def get_next_node(self):
636626

637627
@staticmethod
638628
def dependent_node(up_node_id, node):
629+
if not node.node_chunk.is_end():
630+
return False
639631
if node.id == up_node_id:
640632
if node.type == 'form-node':
641633
if node.context.get('form_data', None) is not None:

apps/setting/models_provider/tools.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
@date:2024/7/22 11:18
77
@desc:
88
"""
9+
from django.db import connection
910
from django.db.models import QuerySet
1011

1112
from common.config.embedding_config import ModelManage
@@ -15,6 +16,8 @@
1516

1617
def get_model_by_id(_id, user_id):
1718
model = QuerySet(Model).filter(id=_id).first()
19+
# 手动关闭数据库连接
20+
connection.close()
1821
if model is None:
1922
raise Exception("模型不存在")
2023
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):

apps/setting/serializers/model_apply_serializers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
@date:2024/8/20 20:39
77
@desc:
88
"""
9+
from django.db import connection
910
from django.db.models import QuerySet
1011
from langchain_core.documents import Document
1112
from rest_framework import serializers
@@ -18,6 +19,8 @@
1819

1920
def get_embedding_model(model_id):
    """Return the embedding model for ``model_id``.

    Fetches the first ``Model`` row whose id matches, closes the database
    connection, and then asks ``ModelManage.get_model`` for the model,
    supplying a factory that builds it from that row with ``use_local=True``.
    """
    model_row = QuerySet(Model).filter(id=model_id).first()
    # Manually close the database connection
    connection.close()

    def build_model(_id):
        return get_model(model_row, use_local=True)

    return ModelManage.get_model(model_id, build_model)

apps/smartdoc/conf.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class Config(dict):
8080
"DB_PORT": 5432,
8181
"DB_USER": "root",
8282
"DB_PASSWORD": "Password123@postgres",
83-
"DB_ENGINE": "django.db.backends.postgresql_psycopg2",
83+
"DB_ENGINE": "dj_db_conn_pool.backends.postgresql",
8484
# 向量模型
8585
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
8686
"EMBEDDING_DEVICE": "cpu",
@@ -108,7 +108,11 @@ def get_db_setting(self) -> dict:
108108
"PORT": self.get('DB_PORT'),
109109
"USER": self.get('DB_USER'),
110110
"PASSWORD": self.get('DB_PASSWORD'),
111-
"ENGINE": self.get('DB_ENGINE')
111+
"ENGINE": self.get('DB_ENGINE'),
112+
"POOL_OPTIONS": {
113+
"POOL_SIZE": 20,
114+
"MAX_OVERFLOW": 5
115+
}
112116
}
113117

114118
def __init__(self, *args):

installer/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ DB_HOST: 127.0.0.1
1313
DB_PORT: 5432
1414
DB_USER: root
1515
DB_PASSWORD: Password123@postgres
16-
DB_ENGINE: django.db.backends.postgresql_psycopg2
16+
DB_ENGINE: dj_db_conn_pool.backends.postgresql
1717
EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding
1818
EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese
1919

0 commit comments

Comments
 (0)