6
6
@date:2024/1/9 17:40
7
7
@desc:
8
8
"""
9
+ import concurrent
9
10
import json
10
11
import threading
11
12
import traceback
12
13
from concurrent .futures import ThreadPoolExecutor
13
14
from functools import reduce
14
15
from typing import List , Dict
15
16
17
+ from django .db import close_old_connections
16
18
from django .db .models import QuerySet
17
19
from langchain_core .prompts import PromptTemplate
18
20
from rest_framework import status
@@ -223,23 +225,6 @@ def pop(self):
223
225
return None
224
226
225
227
226
- class NodeChunk :
227
- def __init__ (self ):
228
- self .status = 0
229
- self .chunk_list = []
230
-
231
- def add_chunk (self , chunk ):
232
- self .chunk_list .append (chunk )
233
-
234
- def end (self , chunk = None ):
235
- if chunk is not None :
236
- self .add_chunk (chunk )
237
- self .status = 200
238
-
239
- def is_end (self ):
240
- return self .status == 200
241
-
242
-
243
228
class WorkflowManage :
244
229
def __init__ (self , flow : Flow , params , work_flow_post_handler : WorkFlowPostHandler ,
245
230
base_to_response : BaseToResponse = SystemToResponse (), form_data = None , image_list = None ,
@@ -273,8 +258,9 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
273
258
self .status = 200
274
259
self .base_to_response = base_to_response
275
260
self .chat_record = chat_record
276
- self .await_future_map = {}
277
261
self .child_node = child_node
262
+ self .future_list = []
263
+ self .lock = threading .Lock ()
278
264
if start_node_id is not None :
279
265
self .load_node (chat_record , start_node_id , start_node_data )
280
266
else :
@@ -319,6 +305,7 @@ def get_node_params(n):
319
305
self .node_context .append (node )
320
306
321
307
def run (self ):
308
+ close_old_connections ()
322
309
if self .params .get ('stream' ):
323
310
return self .run_stream (self .start_node , None )
324
311
return self .run_block ()
@@ -328,8 +315,9 @@ def run_block(self):
328
315
非流式响应
329
316
@return: 结果
330
317
"""
331
- result = self .run_chain_async (None , None )
332
- result .result ()
318
+ self .run_chain_async (None , None )
319
+ while self .is_run ():
320
+ pass
333
321
details = self .get_runtime_details ()
334
322
message_tokens = sum ([row .get ('message_tokens' ) for row in details .values () if
335
323
'message_tokens' in row and row .get ('message_tokens' ) is not None ])
@@ -350,12 +338,22 @@ def run_stream(self, current_node, node_result_future):
350
338
流式响应
351
339
@return:
352
340
"""
353
- result = self .run_chain_async (current_node , node_result_future )
354
- return tools .to_stream_response_simple (self .await_result (result ))
341
+ self .run_chain_async (current_node , node_result_future )
342
+ return tools .to_stream_response_simple (self .await_result ())
355
343
356
- def await_result (self , result ):
344
+ def is_run (self , timeout = 0.1 ):
345
+ self .lock .acquire ()
357
346
try :
358
- while await_result (result ):
347
+ r = concurrent .futures .wait (self .future_list , timeout )
348
+ return len (r .not_done ) > 0
349
+ except Exception as e :
350
+ return True
351
+ finally :
352
+ self .lock .release ()
353
+
354
+ def await_result (self ):
355
+ try :
356
+ while self .is_run ():
359
357
while True :
360
358
chunk = self .node_chunk_manage .pop ()
361
359
if chunk is not None :
@@ -383,42 +381,39 @@ def await_result(self, result):
383
381
'' , True , message_tokens , answer_tokens , {})
384
382
385
383
def run_chain_async (self , current_node , node_result_future ):
386
- return executor .submit (self .run_chain_manage , current_node , node_result_future )
384
+ future = executor .submit (self .run_chain_manage , current_node , node_result_future )
385
+ self .future_list .append (future )
387
386
388
387
def run_chain_manage (self , current_node , node_result_future ):
389
388
if current_node is None :
390
389
start_node = self .get_start_node ()
391
390
current_node = get_node (start_node .type )(start_node , self .params , self )
391
+ self .node_chunk_manage .add_node_chunk (current_node .node_chunk )
392
+ # 添加节点
393
+ self .append_node (current_node )
392
394
result = self .run_chain (current_node , node_result_future )
393
395
if result is None :
394
396
return
395
397
node_list = self .get_next_node_list (current_node , result )
396
398
if len (node_list ) == 1 :
397
399
self .run_chain_manage (node_list [0 ], None )
398
400
elif len (node_list ) > 1 :
399
-
401
+ sorted_node_run_list = sorted ( node_list , key = lambda n : n . node . y )
400
402
# 获取到可执行的子节点
401
403
result_list = [{'node' : node , 'future' : executor .submit (self .run_chain_manage , node , None )} for node in
402
- node_list ]
403
- self .set_await_map (result_list )
404
- [r .get ('future' ).result () for r in result_list ]
405
-
406
- def set_await_map (self , node_run_list ):
407
- sorted_node_run_list = sorted (node_run_list , key = lambda n : n .get ('node' ).node .y )
408
- for index in range (len (sorted_node_run_list )):
409
- self .await_future_map [sorted_node_run_list [index ].get ('node' ).runtime_node_id ] = [
410
- sorted_node_run_list [i ].get ('future' )
411
- for i in range (index )]
404
+ sorted_node_run_list ]
405
+ try :
406
+ self .lock .acquire ()
407
+ for r in result_list :
408
+ self .future_list .append (r .get ('future' ))
409
+ finally :
410
+ self .lock .release ()
412
411
413
412
def run_chain (self , current_node , node_result_future = None ):
414
413
if node_result_future is None :
415
414
node_result_future = self .run_node_future (current_node )
416
415
try :
417
416
is_stream = self .params .get ('stream' , True )
418
- # 处理节点响应
419
- await_future_list = self .await_future_map .get (current_node .runtime_node_id , None )
420
- if await_future_list is not None :
421
- [f .result () for f in await_future_list ]
422
417
result = self .hand_event_node_result (current_node ,
423
418
node_result_future ) if is_stream else self .hand_node_result (
424
419
current_node , node_result_future )
@@ -434,16 +429,14 @@ def hand_node_result(self, current_node, node_result_future):
434
429
if result is not None :
435
430
# 阻塞获取结果
436
431
list (result )
437
- # 添加节点
438
- self .node_context .append (current_node )
439
432
return current_result
440
433
except Exception as e :
441
- # 添加节点
442
- self .node_context .append (current_node )
443
434
traceback .print_exc ()
444
435
self .status = 500
445
436
current_node .get_write_error_context (e )
446
437
self .answer += str (e )
438
+ finally :
439
+ current_node .node_chunk .end ()
447
440
448
441
def append_node (self , current_node ):
449
442
for index in range (len (self .node_context )):
@@ -454,15 +447,14 @@ def append_node(self, current_node):
454
447
self .node_context .append (current_node )
455
448
456
449
def hand_event_node_result (self , current_node , node_result_future ):
457
- node_chunk = NodeChunk ()
458
450
real_node_id = current_node .runtime_node_id
459
451
child_node = {}
452
+ view_type = current_node .view_type
460
453
try :
461
454
current_result = node_result_future .result ()
462
455
result = current_result .write_context (current_node , self )
463
456
if result is not None :
464
457
if self .is_result (current_node , current_result ):
465
- self .node_chunk_manage .add_node_chunk (node_chunk )
466
458
for r in result :
467
459
content = r
468
460
child_node = {}
@@ -487,26 +479,24 @@ def hand_event_node_result(self, current_node, node_result_future):
487
479
'child_node' : child_node ,
488
480
'node_is_end' : node_is_end ,
489
481
'real_node_id' : real_node_id })
490
- node_chunk .add_chunk (chunk )
491
- chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
492
- self .params ['chat_record_id' ],
493
- current_node .id ,
494
- current_node .up_node_id_list ,
495
- '' , False , 0 , 0 , {'node_is_end' : True ,
496
- 'runtime_node_id' : current_node .runtime_node_id ,
497
- 'node_type' : current_node .type ,
498
- 'view_type' : view_type ,
499
- 'child_node' : child_node ,
500
- 'real_node_id' : real_node_id })
501
- node_chunk .end (chunk )
482
+ current_node .node_chunk .add_chunk (chunk )
483
+ chunk = (self .base_to_response
484
+ .to_stream_chunk_response (self .params ['chat_id' ],
485
+ self .params ['chat_record_id' ],
486
+ current_node .id ,
487
+ current_node .up_node_id_list ,
488
+ '' , False , 0 , 0 , {'node_is_end' : True ,
489
+ 'runtime_node_id' : current_node .runtime_node_id ,
490
+ 'node_type' : current_node .type ,
491
+ 'view_type' : view_type ,
492
+ 'child_node' : child_node ,
493
+ 'real_node_id' : real_node_id }))
494
+ current_node .node_chunk .add_chunk (chunk )
502
495
else :
503
496
list (result )
504
- # 添加节点
505
- self .append_node (current_node )
506
497
return current_result
507
498
except Exception as e :
508
499
# 添加节点
509
- self .append_node (current_node )
510
500
traceback .print_exc ()
511
501
chunk = self .base_to_response .to_stream_chunk_response (self .params ['chat_id' ],
512
502
self .params ['chat_record_id' ],
@@ -519,12 +509,12 @@ def hand_event_node_result(self, current_node, node_result_future):
519
509
'view_type' : current_node .view_type ,
520
510
'child_node' : {},
521
511
'real_node_id' : real_node_id })
522
- if not self .node_chunk_manage .contains (node_chunk ):
523
- self .node_chunk_manage .add_node_chunk (node_chunk )
524
- node_chunk .end (chunk )
512
+ current_node .node_chunk .add_chunk (chunk )
525
513
current_node .get_write_error_context (e )
526
514
self .status = 500
527
515
return None
516
+ finally :
517
+ current_node .node_chunk .end ()
528
518
529
519
def run_node_async (self , node ):
530
520
future = executor .submit (self .run_node , node )
@@ -636,6 +626,8 @@ def get_next_node(self):
636
626
637
627
@staticmethod
638
628
def dependent_node (up_node_id , node ):
629
+ if not node .node_chunk .is_end ():
630
+ return False
639
631
if node .id == up_node_id :
640
632
if node .type == 'form-node' :
641
633
if node .context .get ('form_data' , None ) is not None :
0 commit comments