diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 71139a3a34..2cbcdc32f6 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -38,4 +38,5 @@ authentication: permission: switch: false component: false - dataset: false \ No newline at end of file + dataset: false +task_executor_threads: 4 \ No newline at end of file diff --git a/rag/settings.py b/rag/settings.py index 08c19a815f..2f3a74f707 100644 --- a/rag/settings.py +++ b/rag/settings.py @@ -32,6 +32,8 @@ pass DOC_MAXIMUM_SIZE = 128 * 1024 * 1024 +TASK_EXECUTOR_THREADS = get_base_config("task_executor_threads", 4) + # Logger LoggerFactory.set_directory( os.path.join( diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0351afa51a..03112b4e61 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -21,6 +21,7 @@ import copy import re import sys +import threading import time import traceback from functools import partial @@ -28,7 +29,7 @@ from api.db.services.file2document_service import File2DocumentService from rag.utils.minio_conn import MINIO from api.db.db_models import close_connection -from rag.settings import database_logger, SVR_QUEUE_NAME +from rag.settings import database_logger, SVR_QUEUE_NAME, TASK_EXECUTOR_THREADS from rag.settings import cron_logger, DOC_MAXIMUM_SIZE from multiprocessing import Pool import numpy as np @@ -304,6 +305,11 @@ def main(): r["id"], tk_count, len(cks), timer()-st)) +def worker(thread_number): + cron_logger.info("Task worker run : {}".format(str(thread_number))) + while True: + main() + if __name__ == "__main__": peewee_logger = logging.getLogger('peewee') @@ -311,5 +317,6 @@ def main(): peewee_logger.addHandler(database_logger.handlers[0]) peewee_logger.setLevel(database_logger.level) - while True: - main() + for i in range(TASK_EXECUTOR_THREADS): + t = threading.Thread(target=worker, args=(i,)) + t.start()