From f5d5da52ad3aa6af697a5a6bec9a658c9e893ece Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 13 Sep 2024 16:39:45 +0800 Subject: [PATCH 1/2] pl rules --- api/pyproject.toml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 83aa35c542929b..57a3844200bdbe 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -6,6 +6,9 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.ruff] +exclude=[ + "migrations/*", +] line-length = 120 [tool.ruff.lint] @@ -19,6 +22,13 @@ select = [ "I", # isort rules "N", # pep8-naming "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "PLR6201", # literal-membership "RUF019", # unnecessary-key-check "RUF100", # unused-noqa "RUF101", # redirected-noqa @@ -78,9 +88,6 @@ ignore = [ "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name ] -"migrations/versions/*" = [ - "E501", # line-too-long -] "tests/*" = [ "F401", # unused-import "F811", # redefined-while-unused @@ -88,7 +95,6 @@ ignore = [ [tool.ruff.format] exclude = [ - "migrations/**/*", ] [tool.pytest_env] From ddd25a7a03648e24fe7429f1195581a4cc2dccc0 Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Fri, 13 Sep 2024 16:40:55 +0800 Subject: [PATCH 2/2] apply fixes --- api/app.py | 2 +- api/commands.py | 4 ++-- api/controllers/console/app/audio.py | 2 +- api/controllers/console/auth/oauth.py | 2 +- .../console/datasets/datasets_document.py | 6 +++--- api/controllers/console/explore/audio.py | 2 +- api/controllers/console/explore/completion.py | 4 ++-- .../console/explore/conversation.py | 10 ++++----- .../console/explore/installed_app.py | 2 +- api/controllers/console/explore/message.py | 4 ++-- api/controllers/console/explore/parameter.py | 2 +- .../console/workspace/workspace.py | 2 +- api/controllers/service_api/app/app.py | 2 +- api/controllers/service_api/app/audio.py | 2 +- api/controllers/service_api/app/completion.py | 4 ++-- .../service_api/app/conversation.py | 6 +++--- api/controllers/service_api/app/message.py | 4 ++-- api/controllers/web/app.py | 2 +- api/controllers/web/audio.py | 2 +- api/controllers/web/completion.py | 4 ++-- api/controllers/web/conversation.py | 10 ++++----- api/controllers/web/message.py | 4 ++-- .../agent/output_parser/cot_output_parser.py | 4 ++-- .../app/app_config/base_app_config_manager.py | 2 +- .../easy_ui_based_app/agent/manager.py | 6 +++--- .../easy_ui_based_app/dataset/manager.py | 2 +- .../easy_ui_based_app/variables/manager.py | 6 +++--- .../features/file_upload/manager.py | 4 ++-- api/core/app/apps/advanced_chat/app_runner.py | 4 ++-- .../base_app_generate_response_converter.py | 2 +- api/core/app/apps/base_app_generator.py | 6 +++--- api/core/app/apps/base_app_queue_manager.py | 4 ++-- .../app/apps/message_based_app_generator.py | 6 +++--- api/core/app/apps/workflow/app_runner.py | 4 ++-- .../annotation_reply/annotation_reply.py | 2 +- .../easy_ui_based_generate_task_pipeline.py | 2 +- .../task_pipeline/workflow_cycle_manage.py | 4 ++-- .../index_tool_callback_handler.py | 2 +- api/core/indexing_runner.py | 2 +- api/core/memory/token_buffer_memory.py | 2 +- .../model_runtime/entities/model_entities.py | 12 +++++------ .../model_providers/anthropic/llm/llm.py | 2 +- .../model_providers/azure_openai/tts/tts.py | 4 ++-- .../model_providers/bedrock/llm/llm.py | 10 ++++----- .../bedrock/text_embedding/text_embedding.py | 8 +++---- .../model_providers/google/llm/llm.py | 4 ++-- .../huggingface_hub/llm/llm.py | 4 ++-- .../huggingface_tei/tei_helper.py | 2 +- .../minimax/llm/chat_completion.py | 4 ++-- .../minimax/llm/chat_completion_pro.py | 4 ++-- .../minimax/text_embedding/text_embedding.py | 2 +- .../model_providers/openai/llm/llm.py | 2 +- .../model_providers/openai/tts/tts.py | 4 ++-- .../model_providers/openrouter/llm/llm.py | 1 - .../model_providers/replicate/llm/llm.py | 2 +- .../text_embedding/text_embedding.py | 4 ++-- .../model_providers/tongyi/llm/llm.py | 8 +++---- .../model_providers/upstage/llm/llm.py | 2 +- .../model_providers/vertex_ai/llm/llm.py | 4 ++-- .../legacy/volc_sdk/base/auth.py | 3 +-- .../model_providers/wenxin/llm/llm.py | 2 +- .../xinference/xinference_helper.py | 2 +- .../model_providers/zhipuai/llm/llm.py | 21 ++++++++----------- .../zhipuai_sdk/types/fine_tuning/__init__.py | 5 ++--- .../schema_validators/common_validator.py | 4 ++-- .../vdb/elasticsearch/elasticsearch_vector.py | 4 ++-- .../datasource/vdb/myscale/myscale_vector.py | 4 ++-- .../rag/datasource/vdb/oracle/oraclevector.py | 10 +-------- api/core/rag/extractor/extract_processor.py | 12 +++++------ .../rag/extractor/firecrawl/firecrawl_app.py | 2 +- api/core/rag/extractor/notion_extractor.py | 4 ++-- api/core/rag/retrieval/dataset_retrieval.py | 2 +- api/core/rag/splitter/text_splitter.py | 2 +- api/core/tools/provider/app_tool_provider.py | 2 +- .../provider/builtin/aippt/tools/aippt.py | 4 ++-- .../builtin/azuredalle/tools/dalle3.py | 4 ++-- .../builtin/code/tools/simple_code.py | 2 +- .../builtin/cogview/tools/cogview3.py | 4 ++-- .../provider/builtin/dalle/tools/dalle3.py | 4 ++-- .../builtin/hap/tools/get_worksheet_fields.py | 4 ++-- .../hap/tools/list_worksheet_records.py | 6 +++--- .../novitaai/tools/novitaai_modelquery.py | 2 +- .../builtin/searchapi/tools/google.py | 2 +- .../builtin/searchapi/tools/google_jobs.py | 2 +- .../builtin/searchapi/tools/google_news.py | 2 +- .../searchapi/tools/youtube_transcripts.py | 2 +- .../provider/builtin/spider/spiderApp.py | 2 +- .../builtin/stability/tools/text2image.py | 2 +- .../provider/builtin/vanna/tools/vanna.py | 2 +- .../tools/provider/builtin_tool_provider.py | 2 +- api/core/tools/provider/tool_provider.py | 18 ++++++++-------- api/core/tools/tool/api_tool.py | 8 +++---- api/core/tools/tool_engine.py | 12 +++-------- api/core/tools/utils/message_transformer.py | 2 +- api/core/tools/utils/parser.py | 2 +- api/core/tools/utils/web_reader_tool.py | 2 +- .../entities/runtime_route_state.py | 2 +- .../answer/answer_stream_generate_router.py | 4 ++-- .../nodes/end/end_stream_generate_router.py | 4 ++-- .../nodes/http_request/http_executor.py | 8 +++---- .../nodes/parameter_extractor/entities.py | 4 ++-- .../parameter_extractor_node.py | 10 ++++----- api/core/workflow/nodes/tool/tool_node.py | 5 +---- api/libs/oauth_data_source.py | 4 ++-- api/libs/rsa.py | 2 +- api/models/dataset.py | 6 +++--- api/models/model.py | 6 +++--- api/services/account_service.py | 6 +++--- api/services/app_dsl_service.py | 8 +++---- api/services/app_service.py | 2 +- api/services/audio_service.py | 4 ++-- api/services/auth/firecrawl.py | 2 +- api/services/dataset_service.py | 2 +- api/services/tools/tools_transform_service.py | 2 +- api/services/workflow_service.py | 2 +- api/tasks/recover_document_indexing_task.py | 2 +- .../model_runtime/__mock/google.py | 4 +--- .../model_runtime/__mock/openai_chat.py | 4 ++-- .../model_runtime/__mock/openai_completion.py | 2 +- .../model_runtime/__mock/openai_embeddings.py | 2 +- .../model_runtime/__mock/openai_moderation.py | 2 +- .../__mock/openai_speech2text.py | 2 +- .../model_runtime/__mock/xinference.py | 4 ++-- .../nodes/test_parameter_extractor.py | 2 +- .../graph_engine/test_graph_engine.py | 6 +++--- 125 files changed, 243 insertions(+), 268 deletions(-) diff --git a/api/app.py b/api/app.py index ad219ca0d67459..91a49337fccbde 100644 --- a/api/app.py +++ b/api/app.py @@ -164,7 +164,7 @@ def initialize_extensions(app): @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in ["console", "inner_api"]: + if request.blueprint not in {"console", "inner_api"}: return None # Check if the user_id contains a dot, indicating the old format auth_header = request.headers.get("Authorization", "") diff --git a/api/commands.py b/api/commands.py index 887270b43e6778..3a6b4963cfa681 100644 --- a/api/commands.py +++ b/api/commands.py @@ -140,9 +140,9 @@ def reset_encrypt_key_pair(): @click.command("vdb-migrate", help="migrate vector db.") @click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") def vdb_migrate(scope: str): - if scope in ["knowledge", "all"]: + if scope in {"knowledge", "all"}: migrate_knowledge_vector_database() - if scope in ["annotation", "all"]: + if scope in {"annotation", "all"}: migrate_annotation_vector_database() diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 7332758e83c60a..c1ef05a488635c 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -94,7 +94,7 @@ def post(self, app_model): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 1df0f5de9d5e8f..ad0c0580aeaba4 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -71,7 +71,7 @@ def get(self, provider: str): account = _generate_account(provider, user_info) # Check account status - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 076f3cd44d5af5..829ef11e521d8e 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -354,7 +354,7 @@ def get(self, dataset_id, document_id): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -421,7 +421,7 @@ def get(self, dataset_id, batch): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ["completed", "error"]: + if document.indexing_status in {"completed", "error"}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info @@ -665,7 +665,7 @@ def patch(self, dataset_id, document_id, action): db.session.commit() elif action == "resume": - if document.indexing_status not in ["paused", "error"]: + if document.indexing_status not in {"paused", "error"}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 2eb7e0449037d1..9690677f61b1c2 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -81,7 +81,7 @@ def post(self, installed_app): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index c039e8bca52a29..f4646920982bdb 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -92,7 +92,7 @@ class ChatApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -140,7 +140,7 @@ class ChatStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 2918024b64cc11..6f9d7769b942ce 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -20,7 +20,7 @@ class ConversationListApi(InstalledAppResource): def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -50,7 +50,7 @@ class ConversationApi(InstalledAppResource): def delete(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -68,7 +68,7 @@ class ConversationRenameApi(InstalledAppResource): def post(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -90,7 +90,7 @@ class ConversationPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -107,7 +107,7 @@ class ConversationUnPinApi(InstalledAppResource): def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 3f1e64a2478b4b..408afc33a0c7c9 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -31,7 +31,7 @@ def get(self): "app_owner_tenant_id": installed_app.app_owner_tenant_id, "is_pinned": installed_app.is_pinned, "last_used_at": installed_app.last_used_at, - "editable": current_user.role in ["owner", "admin"], + "editable": current_user.role in {"owner", "admin"}, "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index f5eb18517258d5..0e0238556cf9aa 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -40,7 +40,7 @@ def get(self, installed_app): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -125,7 +125,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() message_id = str(message_id) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index ad55b040436d47..aab7dd788831b2 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -43,7 +43,7 @@ def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 623f0b8b74dfdb..af3ebc099b1baa 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -194,7 +194,7 @@ def post(self): raise TooManyFilesError() extension = file.filename.split(".")[-1] - if extension.lower() not in ["svg", "png"]: + if extension.lower() not in {"svg", "png"}: raise UnsupportedFileTypeError() try: diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index ecc2d73deb49fb..f7c091217b4401 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -42,7 +42,7 @@ class AppParameterApi(Resource): @marshal_with(parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 8d8ca8d78c0d55..5db41636471220 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -79,7 +79,7 @@ def post(self, app_model: App, end_user: EndUser): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index f1771baf314b47..8d8e356c4cb940 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -96,7 +96,7 @@ class ChatApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -144,7 +144,7 @@ class ChatStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 734027a1c51e5c..527ef4ecd366af 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -18,7 +18,7 @@ class ConversationApi(Resource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -52,7 +52,7 @@ class ConversationDetailApi(Resource): @marshal_with(simple_conversation_fields) def delete(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -69,7 +69,7 @@ class ConversationRenameApi(Resource): @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index b39aaf7dd804d0..e54e6f4903d574 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -76,7 +76,7 @@ class MessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -117,7 +117,7 @@ class MessageSuggestedApi(Resource): def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() try: diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index aabca93338f17c..20b4e4674cc511 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -41,7 +41,7 @@ class AppParameterApi(WebApiResource): @marshal_with(parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise AppUnavailableError() diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 49c467dbe18bb5..23550efe2e2768 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -78,7 +78,7 @@ def post(self, app_model: App, end_user): message_id = args.get("message_id", None) text = args.get("text", None) if ( - app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value} and app_model.workflow and app_model.workflow.features_dict ): diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 0837eedfb068a9..115492b7966c01 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -87,7 +87,7 @@ def post(self, app_model, end_user, task_id): class ChatApi(WebApiResource): def post(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -136,7 +136,7 @@ def post(self, app_model, end_user): class ChatStopApi(WebApiResource): def post(self, app_model, end_user, task_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 6bbfa94c2756a5..c3b0cd4f44b2ac 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -18,7 +18,7 @@ class ConversationListApi(WebApiResource): @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -56,7 +56,7 @@ def get(self, app_model, end_user): class ConversationApi(WebApiResource): def delete(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -73,7 +73,7 @@ class ConversationRenameApi(WebApiResource): @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -92,7 +92,7 @@ def post(self, app_model, end_user, c_id): class ConversationPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) @@ -108,7 +108,7 @@ def patch(self, app_model, end_user, c_id): class ConversationUnPinApi(WebApiResource): def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() conversation_id = str(c_id) diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 56aaaa930a4a87..0d4047f4efbaf8 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -78,7 +78,7 @@ class MessageListApi(WebApiResource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotChatAppError() parser = reqparse.RequestParser() @@ -160,7 +160,7 @@ def get(self, app_model, end_user, message_id): class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) - if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: + if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: raise NotCompletionAppError() message_id = str(message_id) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 1a161677dda345..d04e38777a54aa 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -90,7 +90,7 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, if not in_code_block and not in_json: if delta.lower() == action_str[action_idx] and action_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue @@ -117,7 +117,7 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, action_idx = 0 if delta.lower() == thought_str[thought_idx] and thought_idx == 0: - if last_character not in ["\n", " ", ""]: + if last_character not in {"\n", " ", ""}: index += steps yield delta continue diff --git a/api/core/app/app_config/base_app_config_manager.py b/api/core/app/app_config/base_app_config_manager.py index 0fd2a779a4800d..24d80f9cdd77f7 100644 --- a/api/core/app/app_config/base_app_config_manager.py +++ b/api/core/app/app_config/base_app_config_manager.py @@ -29,7 +29,7 @@ def convert_features(cls, config_dict: Mapping[str, Any], app_mode: AppMode) -> additional_features.show_retrieve_source = RetrievalResourceConfigManager.convert(config=config_dict) additional_features.file_upload = FileUploadConfigManager.convert( - config=config_dict, is_vision=app_mode in [AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT] + config=config_dict, is_vision=app_mode in {AppMode.CHAT, AppMode.COMPLETION, AppMode.AGENT_CHAT} ) additional_features.opening_statement, additional_features.suggested_questions = ( diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 6e89f19508c01f..f503543d7bd0f5 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -18,7 +18,7 @@ def convert(cls, config: dict) -> Optional[AgentEntity]: if agent_strategy == "function_call": strategy = AgentEntity.Strategy.FUNCTION_CALLING - elif agent_strategy == "cot" or agent_strategy == "react": + elif agent_strategy in {"cot", "react"}: strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT else: # old configs, try to detect default strategy @@ -43,10 +43,10 @@ def convert(cls, config: dict) -> Optional[AgentEntity]: agent_tools.append(AgentToolEntity(**agent_tool_properties)) - if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in [ + if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", "router", - ]: + }: agent_prompt = agent_dict.get("prompt", None) or {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index ff131b62e27543..a22395b8e39a03 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -167,7 +167,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod config["agent_mode"]["strategy"] = PlanningStrategy.ROUTER.value has_datasets = False - if config["agent_mode"]["strategy"] in [PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value]: + if config["agent_mode"]["strategy"] in {PlanningStrategy.ROUTER.value, PlanningStrategy.REACT_ROUTER.value}: for tool in config["agent_mode"]["tools"]: key = list(tool.keys())[0] if key == "dataset": diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index e70522f21de95d..a1bfde32085ae3 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -42,12 +42,12 @@ def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataV variable=variable["variable"], type=variable["type"], config=variable["config"] ) ) - elif variable_type in [ + elif variable_type in { VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH, VariableEntityType.NUMBER, VariableEntityType.SELECT, - ]: + }: variable = variables[variable_type] variable_entities.append( VariableEntity( @@ -97,7 +97,7 @@ def validate_variables_and_set_defaults(cls, config: dict) -> tuple[dict, list[s variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: + if key not in {"text-input", "select", "paragraph", "number", "external_data_tool"}: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 5f7fc99151d2e8..7a275cb532f6d7 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -54,14 +54,14 @@ def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tupl if is_vision: detail = config["file_upload"]["image"]["detail"] - if detail not in ["high", "low"]: + if detail not in {"high", "low"}: raise ValueError("detail must be in ['high', 'low']") transfer_methods = config["file_upload"]["image"]["transfer_methods"] if not isinstance(transfer_methods, list): raise ValueError("transfer_methods must be of list type") for method in transfer_methods: - if method not in ["remote_url", "local_file"]: + if method not in {"remote_url", "local_file"}: raise ValueError("transfer_methods must be in ['remote_url', 'local_file']") return config, ["file_upload"] diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c4cdba64419234..1bca1e1b71adce 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -73,7 +73,7 @@ def run(self) -> None: raise ValueError("Workflow not initialized") user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -175,7 +175,7 @@ def run(self) -> None: user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 73025d99d056ab..c6855ac85494d6 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -16,7 +16,7 @@ class AppGenerateResponseConverter(ABC): def convert( cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom ) -> dict[str, Any] | Generator[str, Any, None]: - if invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]: + if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) else: diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index ce6f7d43387599..15be7000fc28bf 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -22,11 +22,11 @@ def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): return var.default or "" if ( var.type - in ( + in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - ) + } and user_input_value and not isinstance(user_input_value, str) ): @@ -44,7 +44,7 @@ def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity): options = var.options or [] if user_input_value not in options: raise ValueError(f"{var.variable} in input form must be one of the following: {options}") - elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH): + elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}: if var.max_length and user_input_value and len(user_input_value) > var.max_length: raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index f3c3199354f198..4c4d282e99b6ae 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -32,7 +32,7 @@ def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom) -> None: self._user_id = user_id self._invoke_from = invoke_from - user_prefix = "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" redis_client.setex( AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) @@ -118,7 +118,7 @@ def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> N if result is None: return - user_prefix = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end-user" + user_prefix = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end-user" if result.decode("utf-8") != f"{user_prefix}-{user_id}": return diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index f629c5c8b73a29..c4db95cbd0c4a6 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -148,7 +148,7 @@ def _init_generate_records( # get from source end_user_id = None account_id = None - if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: from_source = "api" end_user_id = application_generate_entity.user_id else: @@ -165,11 +165,11 @@ def _init_generate_records( model_provider = application_generate_entity.model_conf.provider model_id = application_generate_entity.model_conf.model override_model_configs = None - if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in [ + if app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS and app_config.app_mode in { AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION, - ]: + }: override_model_configs = app_config.app_model_config_dict # get conversation introduction diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 81c8463dd54a7b..22ec228fa79b11 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -53,7 +53,7 @@ def run(self) -> None: app_config = cast(WorkflowAppConfig, app_config) user_id = None - if self.application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]: + if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id @@ -113,7 +113,7 @@ def run(self) -> None: user_id=self.application_generate_entity.user_id, user_from=( UserFrom.ACCOUNT - if self.application_generate_entity.invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] + if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else UserFrom.END_USER ), invoke_from=self.application_generate_entity.invoke_from, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 2e37a126c3eb31..77b6bb554c65ec 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -63,7 +63,7 @@ def query( score = documents[0].metadata["score"] annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: - if invoke_from in [InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP]: + if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: from_source = "api" else: from_source = "console" diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 659503301e59ca..8f834b6458ea4b 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -372,7 +372,7 @@ def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> No self._message, application_generate_entity=self._application_generate_entity, conversation=self._conversation, - is_first_message=self._application_generate_entity.app_config.app_mode in [AppMode.AGENT_CHAT, AppMode.CHAT] + is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT} and self._application_generate_entity.conversation_id is None, extras=self._application_generate_entity.extras, ) diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index a030d5dcbf3c05..f10189798f85fc 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -383,7 +383,7 @@ def _workflow_node_start_to_stream_response( :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None response = NodeStartStreamResponse( @@ -430,7 +430,7 @@ def _workflow_node_finish_to_stream_response( :param workflow_node_execution: workflow node execution :return: """ - if workflow_node_execution.node_type in [NodeType.ITERATION.value, NodeType.LOOP.value]: + if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: return None return NodeFinishStreamResponse( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 6d5393ce5c692d..7cf472d984a38c 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -29,7 +29,7 @@ def on_query(self, query: str, dataset_id: str) -> None: source="app", source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" ), created_by=self._user_id, ) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index eeb1dbfda0b77e..af20df41b1f286 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -292,7 +292,7 @@ def _extract( self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict ) -> list[Document]: # load file - if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]: + if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}: return [] data_source_info = dataset_document.data_source_info_dict diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index a14d237a12a393..d3185c3b11aecb 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -52,7 +52,7 @@ def get_history_prompt_messages( files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if self.conversation.mode not in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) else: if message.workflow_run_id: diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index d898ef149092f6..52ea787c3ad572 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -27,17 +27,17 @@ def value_of(cls, origin_model_type: str) -> "ModelType": :return: model type """ - if origin_model_type == "text-generation" or origin_model_type == cls.LLM.value: + if origin_model_type in {"text-generation", cls.LLM.value}: return cls.LLM - elif origin_model_type == "embeddings" or origin_model_type == cls.TEXT_EMBEDDING.value: + elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}: return cls.TEXT_EMBEDDING - elif origin_model_type == "reranking" or origin_model_type == cls.RERANK.value: + elif origin_model_type in {"reranking", cls.RERANK.value}: return cls.RERANK - elif origin_model_type == "speech2text" or origin_model_type == cls.SPEECH2TEXT.value: + elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}: return cls.SPEECH2TEXT - elif origin_model_type == "tts" or origin_model_type == cls.TTS.value: + elif origin_model_type in {"tts", cls.TTS.value}: return cls.TTS - elif origin_model_type == "text2img" or origin_model_type == cls.TEXT2IMG.value: + elif origin_model_type in {"text2img", cls.TEXT2IMG.value}: return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index ff741e02402b1f..46e1b415b81253 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -494,7 +494,7 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 8db044b24dd278..af178703a06951 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -85,14 +85,14 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index c34c20ced35a2a..06a860690193a9 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -454,7 +454,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: base64_data = data_split[1] image_content = base64.b64decode(base64_data) - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" @@ -886,16 +886,16 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2d898e3aaacfec..251170d1aec492 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -186,16 +186,16 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I if error_code == "AccessDeniedException": return InvokeAuthorizationError(error_msg) - elif error_code in ["ResourceNotFoundException", "ValidationException"]: + elif error_code in {"ResourceNotFoundException", "ValidationException"}: return InvokeBadRequestError(error_msg) - elif error_code in ["ThrottlingException", "ServiceQuotaExceededException"]: + elif error_code in {"ThrottlingException", "ServiceQuotaExceededException"}: return InvokeRateLimitError(error_msg) - elif error_code in [ + elif error_code in { "ModelTimeoutException", "ModelErrorException", "InternalServerException", "ModelNotReadyException", - ]: + }: return InvokeServerUnavailableError(error_msg) elif error_code == "ModelStreamErrorException": return InvokeConnectionError(error_msg) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b10d0edba355eb..3fc6787a444e41 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -6,10 +6,10 @@ from typing import Optional, Union, cast import google.ai.generativelanguage as glm -import google.api_core.exceptions as exceptions import google.generativeai as genai -import google.generativeai.client as client import requests +from google.api_core import exceptions +from google.generativeai import client from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part from PIL import Image diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index 48ab477c50ca8c..9d29237fdde573 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -77,7 +77,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: if "huggingfacehub_api_type" not in credentials: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type must be provided.") - if credentials["huggingfacehub_api_type"] not in ("inference_endpoints", "hosted_inference_api"): + if credentials["huggingfacehub_api_type"] not in {"inference_endpoints", "hosted_inference_api"}: raise CredentialsValidateFailedError("Huggingface Hub Endpoint Type is invalid.") if "huggingfacehub_api_token" not in credentials: @@ -94,7 +94,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: credentials["huggingfacehub_api_token"], model ) - if credentials["task_type"] not in ("text2text-generation", "text-generation"): + if credentials["task_type"] not in {"text2text-generation", "text-generation"}: raise CredentialsValidateFailedError( "Huggingface Hub Task Type must be one of text2text-generation, text-generation." ) diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 288637495f6c87..81ab2492144e86 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -75,7 +75,7 @@ def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: if len(model_type.keys()) < 1: raise RuntimeError("model_type is empty") model_type = list(model_type.keys())[0] - if model_type not in ["embedding", "reranker"]: + if model_type not in {"embedding", "reranker"}: raise RuntimeError(f"invalid model_type: {model_type}") max_input_length = response_json.get("max_input_length", 512) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 96f99c892978e2..88cc0e8e0f32d0 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -100,9 +100,9 @@ def generate( return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 0a2a67a56d78cd..8b8fdbb6bdf558 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -105,9 +105,9 @@ def generate( return self._handle_chat_generate_response(response) def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001 or code == 1013 or code == 1027: + if code in {1000, 1001, 1013, 1027}: raise InternalServerError(msg) - elif code == 1002 or code == 1039: + elif code in {1002, 1039}: raise RateLimitReachedError(msg) elif code == 1004: raise InvalidAuthenticationError(msg) diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 02a53708be9c85..76fd1342bdb929 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -114,7 +114,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: raise CredentialsValidateFailedError("Invalid api key") def _handle_error(self, code: int, msg: str): - if code == 1000 or code == 1001: + if code in {1000, 1001}: raise InternalServerError(msg) elif code == 1002: raise RateLimitReachedError(msg) diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 60d69c6e472156..d42fce528a8f30 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -125,7 +125,7 @@ def _code_block_mode_wrapper( model_mode = self.get_model_mode(base_model, credentials) # transform response format - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] if model_mode == LLMMode.CHAT: # chat model diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index b50b43199feee5..a14c91639b88fc 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -89,14 +89,14 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str for i in range(len(sentences)) ] for future in futures: - yield from future.result().__enter__().iter_bytes(1024) + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 else: response = client.audio.speech.with_streaming_response.create( model=model, voice=voice, response_format="mp3", input=content_text.strip() ) - yield from response.__enter__().iter_bytes(1024) + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 except Exception as ex: raise InvokeBadRequestError(str(ex)) diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index 71b5745f7d1fba..b6bb249a0407bc 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -12,7 +12,6 @@ def _update_credential(self, model: str, credentials: dict): credentials["endpoint_url"] = "https://openrouter.ai/api/v1" credentials["mode"] = self.get_model_mode(model).value credentials["function_calling_type"] = "tool_call" - return def _invoke( self, diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index daef8949fb87a4..3641b35dc02a39 100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -154,7 +154,7 @@ def _get_customizable_model_parameter_rules(cls, model: str, credentials: dict) ) for key, value in input_properties: - if key not in ["system_prompt", "prompt"] and "stop" not in key: + if key not in {"system_prompt", "prompt"} and "stop" not in key: value_type = value.get("type") if not value_type: diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index f6b7754d74f4c1..71b6fb99c4a1b2 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -86,7 +86,7 @@ def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) ) for input_property in input_properties: - if input_property[0] in ("text", "texts", "inputs"): + if input_property[0] in {"text", "texts", "inputs"}: text_input_key = input_property[0] return text_input_key @@ -96,7 +96,7 @@ def _get_text_input_key(model: str, model_version: str, client: ReplicateClient) def _generate_embeddings_by_text_input_key( client: ReplicateClient, replicate_model_version: str, text_input_key: str, texts: list[str] ) -> list[list[float]]: - if text_input_key in ("text", "inputs"): + if text_input_key in {"text", "inputs"}: embeddings = [] for text in texts: result = client.run(replicate_model_version, input={text_input_key: text}) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index cd7718361ff032..1d4eba6668c0ce 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -89,7 +89,7 @@ def get_num_tokens( :param tools: tools for tool calling :return: """ - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") if model == "farui-plus": model = "qwen-farui-plus" @@ -157,7 +157,7 @@ def _generate( mode = self.get_model_mode(model, credentials) - if model in ["qwen-turbo-chat", "qwen-plus-chat"]: + if model in {"qwen-turbo-chat", "qwen-plus-chat"}: model = model.replace("-chat", "") extra_model_kwargs = {} @@ -201,7 +201,7 @@ def _handle_generate_response( :param prompt_messages: prompt messages :return: llm response """ - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError(response.message) # transform assistant message to prompt message assistant_prompt_message = AssistantPromptMessage( @@ -240,7 +240,7 @@ def _handle_generate_stream_response( full_text = "" tool_calls = [] for index, response in enumerate(responses): - if response.status_code != 200 and response.status_code != HTTPStatus.OK: + if response.status_code not in {200, HTTPStatus.OK}: raise ServiceUnavailableError( f"Failed to invoke model {model}, status code: {response.status_code}, " f"message: {response.message}" diff --git a/api/core/model_runtime/model_providers/upstage/llm/llm.py b/api/core/model_runtime/model_providers/upstage/llm/llm.py index 74524e81e281fb..a18ee906248a49 100644 --- a/api/core/model_runtime/model_providers/upstage/llm/llm.py +++ b/api/core/model_runtime/model_providers/upstage/llm/llm.py @@ -93,7 +93,7 @@ def _code_block_mode_wrapper( """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: stop = stop or [] self._transform_chat_json_prompts( model=model, diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index dad5002f357ff3..da69b7cdf382de 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -5,7 +5,6 @@ from collections.abc import Generator from typing import Optional, Union, cast -import google.api_core.exceptions as exceptions import google.auth.transport.requests import vertexai.generative_models as glm from anthropic import AnthropicVertex, Stream @@ -17,6 +16,7 @@ MessageStopEvent, MessageStreamEvent, ) +from google.api_core import exceptions from google.cloud import aiplatform from google.oauth2 import service_account from PIL import Image @@ -346,7 +346,7 @@ def _convert_claude_prompt_message_to_dict(self, message: PromptMessage) -> dict mime_type = data_split[0].replace("data:", "") base64_data = data_split[1] - if mime_type not in ["image/jpeg", "image/png", "image/gif", "image/webp"]: + if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}: raise ValueError( f"Unsupported image type {mime_type}, " f"only support image/jpeg, image/png, image/gif, and image/webp" diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py index 97c77de8d32c33..c22bf8e76de36a 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/base/auth.py @@ -96,7 +96,6 @@ def sign(request, credentials): signing_key = Signer.get_signing_secret_key_v4(credentials.sk, md.date, md.region, md.service) sign = Util.to_hex(Util.hmac_sha256(signing_key, signing_str)) request.headers["Authorization"] = Signer.build_auth_header_v4(sign, md, credentials) - return @staticmethod def hashed_canonical_request_v4(request, meta): @@ -105,7 +104,7 @@ def hashed_canonical_request_v4(request, meta): signed_headers = {} for key in request.headers: - if key in ["Content-Type", "Content-Md5", "Host"] or key.startswith("X-"): + if key in {"Content-Type", "Content-Md5", "Host"} or key.startswith("X-"): signed_headers[key.lower()] = request.headers[key] if "host" in signed_headers: diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index ec3556f7da8a44..f7c160b6b47b79 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -69,7 +69,7 @@ def _code_block_mode_wrapper( """ Code block mode wrapper for invoking large language model """ - if "response_format" in model_parameters and model_parameters["response_format"] in ["JSON", "XML"]: + if "response_format" in model_parameters and model_parameters["response_format"] in {"JSON", "XML"}: response_format = model_parameters["response_format"] stop = stop or [] self._transform_json_prompts( diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 1e05da9c56b27c..619ee1492a9272 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -103,7 +103,7 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: st model_handle_type = "embedding" elif response_json.get("model_type") == "audio": model_handle_type = "audio" - if model_family and model_family in ["ChatTTS", "CosyVoice", "FishAudio"]: + if model_family and model_family in {"ChatTTS", "CosyVoice", "FishAudio"}: model_ability.append("text-to-audio") else: model_ability.append("audio-to-text") diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index f76e51fee9e43b..ea331701abb78a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -186,10 +186,10 @@ def _generate( new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() - if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' - if model not in ("glm-4v", "glm-4v-plus"): + if model not in {"glm-4v", "glm-4v-plus"}: # not support list message continue # get image and @@ -209,10 +209,7 @@ def _generate( ): new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: - if ( - copy_prompt_message.role == PromptMessageRole.USER - or copy_prompt_message.role == PromptMessageRole.TOOL - ): + if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.TOOL}: new_prompt_messages.append(copy_prompt_message) elif copy_prompt_message.role == PromptMessageRole.SYSTEM: new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) @@ -226,7 +223,7 @@ def _generate( else: new_prompt_messages.append(copy_prompt_message) - if model == "glm-4v" or model == "glm-4v-plus": + if model in {"glm-4v", "glm-4v-plus"}: params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = {"model": model, "messages": [], **model_parameters} @@ -270,11 +267,11 @@ def _generate( # chatglm model for prompt_message in new_prompt_messages: # merge system message to user message - if ( - prompt_message.role == PromptMessageRole.SYSTEM - or prompt_message.role == PromptMessageRole.TOOL - or prompt_message.role == PromptMessageRole.USER - ): + if prompt_message.role in { + PromptMessageRole.SYSTEM, + PromptMessageRole.TOOL, + PromptMessageRole.USER, + }: if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": params["messages"][-1]["content"] += "\n\n" + prompt_message.content else: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py index af0991892e084f..416f516ef7bf1c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py @@ -1,5 +1,4 @@ from __future__ import annotations -from .fine_tuning_job import FineTuningJob as FineTuningJob -from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob -from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent +from .fine_tuning_job import FineTuningJob, ListOfFineTuningJob +from .fine_tuning_job_event import FineTuningJobEvent diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/core/model_runtime/schema_validators/common_validator.py index c05edb72e3552e..029ec1a581b2e9 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/core/model_runtime/schema_validators/common_validator.py @@ -75,7 +75,7 @@ def _validate_credential_form_schema( if not isinstance(value, str): raise ValueError(f"Variable {credential_form_schema.variable} should be string") - if credential_form_schema.type in [FormType.SELECT, FormType.RADIO]: + if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: # If the value is in options, no validation is performed if credential_form_schema.options: if value not in [option.value for option in credential_form_schema.options]: @@ -83,7 +83,7 @@ def _validate_credential_form_schema( if credential_form_schema.type == FormType.SWITCH: # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in ["true", "false"]: + if value.lower() not in {"true", "false"}: raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") value = True if value.lower() == "true" else False diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f8300cc2715a6c..8d578551209e09 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -51,7 +51,7 @@ def __init__(self, index_name: str, config: ElasticSearchConfig, attributes: lis def _init_client(self, config: ElasticSearchConfig) -> Elasticsearch: try: parsed_url = urlparse(config.host) - if parsed_url.scheme in ["http", "https"]: + if parsed_url.scheme in {"http", "https"}: hosts = f"{config.host}:{config.port}" else: hosts = f"http://{config.host}:{config.port}" @@ -94,7 +94,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** return uuids def text_exists(self, id: str) -> bool: - return self._client.exists(index=self._collection_name, id=id).__bool__() + return bool(self._client.exists(index=self._collection_name, id=id)) def delete_by_ids(self, ids: list[str]) -> None: for id in ids: diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index 1bd5bcd3e456a7..2320a69a30ad11 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -35,7 +35,7 @@ def __init__(self, collection_name: str, config: MyScaleConfig, metric: str = "C super().__init__(collection_name) self._config = config self._metric = metric - self._vec_order = SortOrder.ASC if metric.upper() in ["COSINE", "L2"] else SortOrder.DESC + self._vec_order = SortOrder.ASC if metric.upper() in {"COSINE", "L2"} else SortOrder.DESC self._client = get_client( host=config.host, port=config.port, @@ -92,7 +92,7 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], ** @staticmethod def escape_str(value: Any) -> str: - return "".join(" " if c in ("\\", "'") else c for c in str(value)) + return "".join(" " if c in {"\\", "'"} else c for c in str(value)) def text_exists(self, id: str) -> bool: results = self._client.query(f"SELECT id FROM {self._config.database}.{self._collection_name} WHERE id='{id}'") diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index b974fa80a4a216..77ec45b4d39507 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -223,15 +223,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: words = pseg.cut(query) current_entity = "" for word, pos in words: - if ( - pos == "nr" - or pos == "Ng" - or pos == "eng" - or pos == "nz" - or pos == "n" - or pos == "ORG" - or pos == "v" - ): # nr: 人名, ns: 地名, nt: 机构名 + if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 current_entity += word else: if current_entity: diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 3181656f592871..fe7eaa32e62cb6 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -98,17 +98,17 @@ def extract( unstructured_api_url = dify_config.UNSTRUCTURED_API_URL unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY if etl_type == "Unstructured": - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = ( UnstructuredMarkdownExtractor(file_path, unstructured_api_url) if is_automatic else MarkdownExtractor(file_path, autodetect_encoding=True) ) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) @@ -134,13 +134,13 @@ def extract( else TextExtractor(file_path, autodetect_encoding=True) ) else: - if file_extension == ".xlsx" or file_extension == ".xls": + if file_extension in {".xlsx", ".xls"}: extractor = ExcelExtractor(file_path) elif file_extension == ".pdf": extractor = PdfExtractor(file_path) - elif file_extension in [".md", ".markdown"]: + elif file_extension in {".md", ".markdown"}: extractor = MarkdownExtractor(file_path, autodetect_encoding=True) - elif file_extension in [".htm", ".html"]: + elif file_extension in {".htm", ".html"}: extractor = HtmlExtractor(file_path) elif file_extension == ".docx": extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by) diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 054ce5f4b2e6c6..17c2087a0ab575 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -32,7 +32,7 @@ def scrape_url(self, url, params=None) -> dict: else: raise Exception(f'Failed to scrape URL. Error: {response["error"]}') - elif response.status_code in [402, 409, 500]: + elif response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 0ee24983a42649..87a4ce08bf3f89 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -103,12 +103,12 @@ def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] multi_select_list = property_value[type] for multi_select in multi_select_list: value.append(multi_select["name"]) - elif type == "rich_text" or type == "title": + elif type in {"rich_text", "title"}: if len(property_value[type]) > 0: value = property_value[type][0]["plain_text"] else: value = "" - elif type == "select" or type == "status": + elif type in {"select", "status"}: if property_value[type]: value = property_value[type]["name"] else: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 12868d6ae43945..124c58f0fe2c4f 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -115,7 +115,7 @@ def retrieve( available_datasets.append(dataset) all_documents = [] - user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user" + user_from = "account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: all_documents = self.single_retrieve( app_id, diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 161c36607da91c..7dd62f8de18a15 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -35,7 +35,7 @@ def _split_text_with_regex(text: str, separator: str, keep_separator: bool) -> l splits = re.split(separator, text) else: splits = list(text) - return [s for s in splits if (s != "" and s != "\n")] + return [s for s in splits if (s not in {"", "\n"})] class TextSplitter(BaseDocumentTransformer, ABC): diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 01544d7e562d52..09f328cd1fe65f 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -68,7 +68,7 @@ def get_tools(self, user_id: str) -> list[Tool]: label = input_form[form_type]["label"] variable_name = input_form[form_type]["variable_name"] options = input_form[form_type].get("options", []) - if form_type == "paragraph" or form_type == "text-input": + if form_type in {"paragraph", "text-input"}: tool["parameters"].append( ToolParameter( name=variable_name, diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index a2d69fbcd1a5ed..dd9371f70d63f5 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -168,7 +168,7 @@ def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate outline: {data}") return outline @@ -213,7 +213,7 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str: pass elif event == "close": break - elif event == "error" or event == "filter": + elif event in {"error", "filter"}: raise Exception(f"Failed to generate content: {data}") return content diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 7462824be1ea55..cfa3cfb092803a 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -39,11 +39,11 @@ def _invoke( n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/code/tools/simple_code.py b/api/core/tools/provider/builtin/code/tools/simple_code.py index 017fe548f76e68..632c9fc7f1451b 100644 --- a/api/core/tools/provider/builtin/code/tools/simple_code.py +++ b/api/core/tools/provider/builtin/code/tools/simple_code.py @@ -14,7 +14,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe language = tool_parameters.get("language", CodeLanguage.PYTHON3) code = tool_parameters.get("code", "") - if language not in [CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT]: + if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}: raise ValueError(f"Only python3 and javascript are supported, not {language}") result = CodeExecutor.execute_code(language, "", code) diff --git a/api/core/tools/provider/builtin/cogview/tools/cogview3.py b/api/core/tools/provider/builtin/cogview/tools/cogview3.py index 9776bd7dd1c02d..9039708588df16 100644 --- a/api/core/tools/provider/builtin/cogview/tools/cogview3.py +++ b/api/core/tools/provider/builtin/cogview/tools/cogview3.py @@ -34,11 +34,11 @@ def _invoke( n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # set extra body seed_id = tool_parameters.get("seed_id", self._generate_random_id(8)) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index bcfa2212b6caec..a8c647d71e69e0 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -49,11 +49,11 @@ def _invoke( n = tool_parameters.get("n", 1) # get quality quality = tool_parameters.get("quality", "standard") - if quality not in ["standard", "hd"]: + if quality not in {"standard", "hd"}: return self.create_text_message("Invalid quality") # get style style = tool_parameters.get("style", "vivid") - if style not in ["natural", "vivid"]: + if style not in {"natural", "vivid"}: return self.create_text_message("Invalid style") # call openapi dalle3 diff --git a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py index 40e1af043b352e..79e5889eaef3b5 100644 --- a/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py +++ b/api/core/tools/provider/builtin/hap/tools/get_worksheet_fields.py @@ -133,9 +133,9 @@ def get_controls(self, controls: list) -> dict: def _extract_options(self, control: dict) -> list: options = [] - if control["type"] in [9, 10, 11]: + if control["type"] in {9, 10, 11}: options.extend([{"key": opt["key"], "value": opt["value"]} for opt in control.get("options", [])]) - elif control["type"] in [28, 36]: + elif control["type"] in {28, 36}: itemnames = control["advancedSetting"].get("itemnames") if itemnames and itemnames.startswith("[{"): try: diff --git a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py index 171895a3060fb1..44c7e523070d0e 100644 --- a/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py +++ b/api/core/tools/provider/builtin/hap/tools/list_worksheet_records.py @@ -183,11 +183,11 @@ def handle_value_type(self, value, field): type_id = field.get("typeId") if type_id == 10: value = value if isinstance(value, str) else "、".join(value) - elif type_id in [28, 36]: + elif type_id in {28, 36}: value = field.get("options", {}).get(value, value) - elif type_id in [26, 27, 48, 14]: + elif type_id in {26, 27, 48, 14}: value = self.process_value(value) - elif type_id in [35, 29]: + elif type_id in {35, 29}: value = self.parse_cascade_or_associated(field, value) elif type_id == 40: value = self.parse_location(value) diff --git a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py index 9ca14b327caaea..a200ee81231f00 100644 --- a/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py +++ b/api/core/tools/provider/builtin/novitaai/tools/novitaai_modelquery.py @@ -35,7 +35,7 @@ def _invoke( models_data=[], headers=headers, params=params, - recursive=not (result_type == "first sd_name" or result_type == "first name sd_name pair"), + recursive=result_type not in {"first sd_name", "first name sd_name pair"}, ) result_str = "" diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index 16ae14549d21d2..17e2978194c6a3 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -38,7 +38,7 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index d29cb0ae3f1cbe..c478bc108b47e1 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -38,7 +38,7 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google_jobs", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 8458c8c958e9ea..562bc01964b4c3 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -38,7 +38,7 @@ def get_params(self, query: str, **kwargs: Any) -> dict[str, str]: return { "engine": "google_news", "q": query, - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 09725cf8a2ab00..1867cf7be79be5 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -38,7 +38,7 @@ def get_params(self, video_id: str, language: str, **kwargs: Any) -> dict[str, s "engine": "youtube_transcripts", "video_id": video_id, "lang": language or "en", - **{key: value for key, value in kwargs.items() if value not in [None, ""]}, + **{key: value for key, value in kwargs.items() if value not in {None, ""}}, } @staticmethod diff --git a/api/core/tools/provider/builtin/spider/spiderApp.py b/api/core/tools/provider/builtin/spider/spiderApp.py index 3972e560c41b3d..4bc446a1a092a3 100644 --- a/api/core/tools/provider/builtin/spider/spiderApp.py +++ b/api/core/tools/provider/builtin/spider/spiderApp.py @@ -214,7 +214,7 @@ def _delete_request(self, url: str, headers, stream=False): return requests.delete(url, headers=headers, stream=stream) def _handle_error(self, response, action): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/core/tools/provider/builtin/stability/tools/text2image.py b/api/core/tools/provider/builtin/stability/tools/text2image.py index 9f415ceb5509ba..6bcf315484ad50 100644 --- a/api/core/tools/provider/builtin/stability/tools/text2image.py +++ b/api/core/tools/provider/builtin/stability/tools/text2image.py @@ -32,7 +32,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe model = tool_parameters.get("model", "core") - if model in ["sd3", "sd3-turbo"]: + if model in {"sd3", "sd3-turbo"}: payload["model"] = tool_parameters.get("model") if model != "sd3-turbo": diff --git a/api/core/tools/provider/builtin/vanna/tools/vanna.py b/api/core/tools/provider/builtin/vanna/tools/vanna.py index a6efb0f79a5d93..c90d766e483aaa 100644 --- a/api/core/tools/provider/builtin/vanna/tools/vanna.py +++ b/api/core/tools/provider/builtin/vanna/tools/vanna.py @@ -38,7 +38,7 @@ def _invoke( vn = VannaDefault(model=model, api_key=api_key) db_type = tool_parameters.get("db_type", "") - if db_type in ["Postgres", "MySQL", "Hive", "ClickHouse"]: + if db_type in {"Postgres", "MySQL", "Hive", "ClickHouse"}: if not db_name: return self.create_text_message("Please input database name") if not username: diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 6b64dd1b4e3201..ff022812ef54d5 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -19,7 +19,7 @@ class BuiltinToolProviderController(ToolProviderController): def __init__(self, **data: Any) -> None: - if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP: + if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}: super().__init__(**data) return diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py index f4008eedce8154..7ba9dda17991c9 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -153,10 +153,10 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None: # check type credential_schema = credentials_need_to_validate[credential_name] - if ( - credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT - ): + if credential_schema in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + }: if not isinstance(credentials[credential_name], str): raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") @@ -184,11 +184,11 @@ def validate_credentials_format(self, credentials: dict[str, Any]) -> None: if credential_schema.default is not None: default_value = credential_schema.default # parse default value into the correct type - if ( - credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT - or credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT - or credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT - ): + if credential_schema.type in { + ToolProviderCredentials.CredentialsType.SECRET_INPUT, + ToolProviderCredentials.CredentialsType.TEXT_INPUT, + ToolProviderCredentials.CredentialsType.SELECT, + }: default_value = str(default_value) credentials[credential_name] = default_value diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index bf336b48f304e8..c779d704c368e5 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,7 +5,7 @@ import httpx -import core.helper.ssrf_proxy as ssrf_proxy +from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError @@ -191,7 +191,7 @@ def do_http_request( else: body = body - if method in ("get", "head", "post", "put", "delete", "patch"): + if method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, method)( url, params=params, @@ -224,9 +224,9 @@ def _convert_body_property_any_of( elif option["type"] == "string": return str(value) elif option["type"] == "boolean": - if str(value).lower() in ["true", "1"]: + if str(value).lower() in {"true", "1"}: return True - elif str(value).lower() in ["false", "0"]: + elif str(value).lower() in {"false", "0"}: return False else: continue # Not a boolean, try next option diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 645f0861fa06b3..9912114dd6a95f 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -189,10 +189,7 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str result += response.message elif response.type == ToolInvokeMessage.MessageType.LINK: result += f"result link: {response.message}. please tell user to check it." - elif ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, you do not need to create it," " just tell the user to check it now." @@ -212,10 +209,7 @@ def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> lis result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: mimetype = None if response.meta.get("mime_type"): mimetype = response.meta.get("mime_type") @@ -297,7 +291,7 @@ def _create_message_files( belongs_to="assistant", url=message.url, upload_file_id=None, - created_by_role=("account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"), + created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"), created_by=user_id, ) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index bf040d91d3987e..3cfab207ba73e4 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -19,7 +19,7 @@ def transform_tool_invoke_messages( result = [] for message in messages: - if message.type == ToolInvokeMessage.MessageType.TEXT or message.type == ToolInvokeMessage.MessageType.LINK: + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}: result.append(message) elif message.type == ToolInvokeMessage.MessageType.IMAGE: # try to download image diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 210b84b29a72f9..9ead4f8e5cf471 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -165,7 +165,7 @@ def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType elif "schema" in parameter and "type" in parameter["schema"]: typ = parameter["schema"]["type"] - if typ == "integer" or typ == "number": + if typ in {"integer", "number"}: return ToolParameter.ToolParameterType.NUMBER elif typ == "boolean": return ToolParameter.ToolParameterType.BOOLEAN diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index e57cae9f16ac80..1ced7d0488e3f2 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -313,7 +313,7 @@ def normalize_whitespace(text): def is_leaf(element): - return element.name in ["p", "li"] + return element.name in {"p", "li"} def is_text(element): diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 8fc80474265d9f..bb24b511127395 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -51,7 +51,7 @@ def set_finished(self, run_result: NodeRunResult) -> None: :param run_result: run result """ - if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + if self.status in {RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED}: raise Exception(f"Route state {self.id} already finished") if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index e31a1479a8f41c..5e6de8fb153060 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -148,11 +148,11 @@ def _recursive_fetch_answer_dependencies( for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.ANSWER.value, NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER.value, - ): + }: answer_dependencies[answer_node_id].append(source_node_id) else: cls._recursive_fetch_answer_dependencies( diff --git a/api/core/workflow/nodes/end/end_stream_generate_router.py b/api/core/workflow/nodes/end/end_stream_generate_router.py index a38d98239385de..9a7d2ecde3b90a 100644 --- a/api/core/workflow/nodes/end/end_stream_generate_router.py +++ b/api/core/workflow/nodes/end/end_stream_generate_router.py @@ -136,10 +136,10 @@ def _recursive_fetch_end_dependencies( for edge in reverse_edges: source_node_id = edge.source_node_id source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type") - if source_node_type in ( + if source_node_type in { NodeType.IF_ELSE.value, NodeType.QUESTION_CLASSIFIER, - ): + }: end_dependencies[end_node_id].append(source_node_id) else: cls._recursive_fetch_end_dependencies( diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 49102dc3ab127a..f8ab4e313241f3 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -6,8 +6,8 @@ import httpx -import core.helper.ssrf_proxy as ssrf_proxy from configs import dify_config +from core.helper import ssrf_proxy from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.http_request.entities import ( @@ -176,7 +176,7 @@ def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional elif node_data.body.type == "x-www-form-urlencoded" and not content_type_is_set: self.headers["Content-Type"] = "application/x-www-form-urlencoded" - if node_data.body.type in ["form-data", "x-www-form-urlencoded"]: + if node_data.body.type in {"form-data", "x-www-form-urlencoded"}: body = self._to_dict(body_data) if node_data.body.type == "form-data": @@ -187,7 +187,7 @@ def _init_template(self, node_data: HttpRequestNodeData, variable_pool: Optional self.headers["Content-Type"] = f"multipart/form-data; boundary={self.boundary}" else: self.body = urlencode(body) - elif node_data.body.type in ["json", "raw-text"]: + elif node_data.body.type in {"json", "raw-text"}: self.body = body_data elif node_data.body.type == "none": self.body = "" @@ -258,7 +258,7 @@ def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: "follow_redirects": True, } - if self.method in ("get", "head", "post", "put", "delete", "patch"): + if self.method in {"get", "head", "post", "put", "delete", "patch"}: response = getattr(ssrf_proxy, self.method)(data=self.body, files=self.files, **kwargs) else: raise ValueError(f"Invalid http method {self.method}") diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py index 802ed31e27a42d..5697d7c04983fe 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/core/workflow/nodes/parameter_extractor/entities.py @@ -33,7 +33,7 @@ class ParameterConfig(BaseModel): def validate_name(cls, value) -> str: if not value: raise ValueError("Parameter name is required") - if value in ["__reason", "__is_success"]: + if value in {"__reason", "__is_success"}: raise ValueError("Invalid parameter name, __reason and __is_success are reserved") return value @@ -66,7 +66,7 @@ def get_parameter_json_schema(self) -> dict: for parameter in self.parameters: parameter_schema = {"description": parameter.description} - if parameter.type in ["string", "select"]: + if parameter.type in {"string", "select"}: parameter_schema["type"] = "string" elif parameter.type.startswith("array"): parameter_schema["type"] = "array" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 131d26b19eaf31..a6454bd1cd28ba 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -467,7 +467,7 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d # transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true') # elif isinstance(result[parameter.name], int): # transformed_result[parameter.name] = bool(result[parameter.name]) - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: if isinstance(result[parameter.name], str): transformed_result[parameter.name] = result[parameter.name] elif parameter.type.startswith("array"): @@ -498,7 +498,7 @@ def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> d transformed_result[parameter.name] = 0 elif parameter.type == "bool": transformed_result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: transformed_result[parameter.name] = "" elif parameter.type.startswith("array"): transformed_result[parameter.name] = [] @@ -516,9 +516,9 @@ def extract_json(text): """ stack = [] for i, c in enumerate(text): - if c == "{" or c == "[": + if c in {"{", "["}: stack.append(c) - elif c == "}" or c == "]": + elif c in {"}", "]"}: # check if stack is empty if not stack: return text[:i] @@ -560,7 +560,7 @@ def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict: result[parameter.name] = 0 elif parameter.type == "bool": result[parameter.name] = False - elif parameter.type in ["string", "select"]: + elif parameter.type in {"string", "select"}: result[parameter.name] = "" return result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index e55adfc1f40272..3b86b29cf8e302 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -163,10 +163,7 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) result = [] for response in tool_response: - if ( - response.type == ToolInvokeMessage.MessageType.IMAGE_LINK - or response.type == ToolInvokeMessage.MessageType.IMAGE - ): + if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: url = response.message ext = path.splitext(url)[1] mimetype = response.meta.get("mime_type", "image/jpeg") diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 6da1a6d39bfd57..05a73b09b7b8c0 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -158,7 +158,7 @@ def get_authorized_pages(self, access_token: str): page_icon = page_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: @@ -191,7 +191,7 @@ def get_authorized_pages(self, access_token: str): page_icon = database_result["icon"] if page_icon: icon_type = page_icon["type"] - if icon_type == "external" or icon_type == "file": + if icon_type in {"external", "file"}: url = page_icon[icon_type]["url"] icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: diff --git a/api/libs/rsa.py b/api/libs/rsa.py index a578bf3e5617f6..637bcc4a1dda61 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -4,9 +4,9 @@ from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes -import libs.gmpy2_pkcs10aep_cipher as gmpy2_pkcs10aep_cipher from extensions.ext_redis import redis_client from extensions.ext_storage import storage +from libs import gmpy2_pkcs10aep_cipher def generate_key_pair(tenant_id): diff --git a/api/models/dataset.py b/api/models/dataset.py index 0da35910cd4b27..a2d2a3454d4931 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -284,9 +284,9 @@ def display_status(self): status = None if self.indexing_status == "waiting": status = "queuing" - elif self.indexing_status not in ["completed", "error", "waiting"] and self.is_paused: + elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused: status = "paused" - elif self.indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}: status = "indexing" elif self.indexing_status == "error": status = "error" @@ -331,7 +331,7 @@ def data_source_detail_dict(self): "created_at": file_detail.created_at.timestamp(), } } - elif self.data_source_type == "notion_import" or self.data_source_type == "website_crawl": + elif self.data_source_type in {"notion_import", "website_crawl"}: return json.loads(self.data_source_info) return {} diff --git a/api/models/model.py b/api/models/model.py index a8b2e00ee44004..ae0bc3210b6465 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -134,7 +134,7 @@ def is_agent(self) -> bool: return False if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get( "strategy", "" - ) in ["function_call", "react"]: + ) in {"function_call", "react"}: self.mode = AppMode.AGENT_CHAT.value db.session.commit() return True @@ -1501,6 +1501,6 @@ def to_dict(self): "tracing_provider": self.tracing_provider, "tracing_config": self.tracing_config_dict, "is_active": self.is_active, - "created_at": self.created_at.__str__() if self.created_at else None, - "updated_at": self.updated_at.__str__() if self.updated_at else None, + "created_at": str(self.created_at) if self.created_at else None, + "updated_at": str(self.updated_at) if self.updated_at else None, } diff --git a/api/services/account_service.py b/api/services/account_service.py index e839ae54bae722..66ff5d2b7ce7eb 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -47,7 +47,7 @@ def load_user(user_id: str) -> None | Account: if not account: return None - if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( @@ -92,7 +92,7 @@ def authenticate(email: str, password: str) -> Account: if not account: raise AccountLoginError("Invalid email or password.") - if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: + if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: @@ -427,7 +427,7 @@ def check_member_permission(tenant: Tenant, operator: Account, member: Account, "remove": [TenantAccountRole.OWNER], "update": [TenantAccountRole.OWNER], } - if action not in ["add", "remove", "update"]: + if action not in {"add", "remove", "update"}: raise InvalidActionError("Invalid action.") if member: diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 2fe39b522491a7..54594e1175de77 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -90,7 +90,7 @@ def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, accoun # import dsl and create app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -103,7 +103,7 @@ def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, accoun icon_background=icon_background, use_icon_as_answer_icon=use_icon_as_answer_icon, ) - elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, @@ -143,7 +143,7 @@ def import_and_overwrite_workflow(cls, app_model: App, data: str, account: Accou # import dsl and overwrite app app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: raise ValueError("Only support import workflow in advanced-chat or workflow app.") if app_data.get("mode") != app_model.mode: @@ -177,7 +177,7 @@ def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: }, } - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: cls._append_workflow_export_data( export_data=export_data, app_model=app_model, include_secret=include_secret ) diff --git a/api/services/app_service.py b/api/services/app_service.py index 1dacfea246398f..ac45d623e84bc9 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -316,7 +316,7 @@ def get_app_meta(self, app_model: App) -> dict: meta = {"tool_icons": {}} - if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: workflow = app_model.workflow if workflow is None: return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 05cd1c96a1d6a6..7a0cd5725b2a96 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -25,7 +25,7 @@ class AudioService: @classmethod def transcript_asr(cls, app_model: App, file: FileStorage, end_user: Optional[str] = None): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("Speech to text is not enabled") @@ -83,7 +83,7 @@ def transcript_tts( def invoke_tts(text_content: str, app_model, voice: Optional[str] = None): with app.app_context(): - if app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value]: + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: raise ValueError("TTS is not enabled") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 30e4ee57c06b45..afc491398f25f3 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -37,7 +37,7 @@ def _post_request(self, url, data, headers): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): - if response.status_code in [402, 409, 500]: + if response.status_code in {402, 409, 500}: error_message = response.json().get("error", "Unknown error occurred") raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index fa017bfa42a7c9..30c010ef29f623 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -544,7 +544,7 @@ def rename_document(dataset_id: str, document_id: str, name: str) -> Document: @staticmethod def pause_document(document): - if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]: + if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: raise DocumentIndexingError() # update document to be paused document.is_paused = True diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 6fb0f2f517376d..7ae1b9f23173ee 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -33,7 +33,7 @@ def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str if provider_type == ToolProviderType.BUILT_IN.value: return url_prefix + "builtin/" + provider_name + "/icon" - elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: + elif provider_type in {ToolProviderType.API.value, ToolProviderType.WORKFLOW.value}: try: return json.loads(icon) except: diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 357ffd41c127a5..0ff81f1f7e834d 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -295,7 +295,7 @@ def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> A # chatbot convert to workflow mode workflow_converter = WorkflowConverter() - if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: + if app_model.mode not in {AppMode.CHAT.value, AppMode.COMPLETION.value}: raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index 21ea11d4dd7d2e..934eb7430c90c3 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -29,7 +29,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): try: indexing_runner = IndexingRunner() - if document.indexing_status in ["waiting", "parsing", "cleaning"]: + if document.indexing_status in {"waiting", "parsing", "cleaning"}: indexing_runner.run([document]) elif document.indexing_status == "splitting": indexing_runner.run_in_splitting_status(document) diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index bc0684086f0b20..402bd9c2c21f69 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,15 +1,13 @@ from collections.abc import Generator -import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types -import google.generativeai.types.safety_types as safety_types import pytest from _pytest.monkeypatch import MonkeyPatch from google.ai import generativelanguage as glm from google.ai.generativelanguage_v1beta.types import content as gag_content from google.generativeai import GenerativeModel from google.generativeai.client import _ClientManager, configure -from google.generativeai.types import GenerateContentResponse +from google.generativeai.types import GenerateContentResponse, content_types, safety_types from google.generativeai.types.generation_types import BaseGenerateContentResponse current_api_key = "" diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py index d9cd7b046e001c..439f7d56e9b5d5 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py @@ -6,7 +6,6 @@ # import monkeypatch from typing import Any, Literal, Optional, Union -import openai.types.chat.completion_create_params as completion_create_params from openai import AzureOpenAI, OpenAI from openai._types import NOT_GIVEN, NotGiven from openai.resources.chat.completions import Completions @@ -18,6 +17,7 @@ ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam, + completion_create_params, ) from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice @@ -254,7 +254,7 @@ def chat_create( "gpt-3.5-turbo-16k-0613", ] azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py index c27e89248f4403..14223668e036d9 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py @@ -112,7 +112,7 @@ def completion_create( ] azure_openai_models = ["gpt-35-turbo-instruct"] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if model in openai_models + azure_openai_models: if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py index 025913cb17a805..e27b9891f5c8a8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py @@ -22,7 +22,7 @@ def create_embeddings( if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py index 270a88e85ffedd..4262d40f3e5464 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py @@ -20,7 +20,7 @@ def moderation_create( if isinstance(input, str): input = [input] - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py index ef361e86139427..a51dcab4be7467 100644 --- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py +++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py @@ -20,7 +20,7 @@ def speech2text_create( temperature: float | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> Transcription: - if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()): + if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)): raise InvokeAuthorizationError("Invalid base url") if len(self._client.api_key) < 18: diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index 777737187e259c..299523f4f5b7c2 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -42,7 +42,7 @@ def get(self: Session, url: str, **kwargs): model_uid = url.split("/")[-1] or "" if not re.match( r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid - ) and model_uid not in ["generate", "chat", "embedding", "rerank"]: + ) and model_uid not in {"generate", "chat", "embedding", "rerank"}: response.status_code = 404 response._content = b"{}" return response @@ -53,7 +53,7 @@ def get(self: Session, url: str, **kwargs): response._content = b"{}" return response - if model_uid in ["generate", "chat"]: + if model_uid in {"generate", "chat"}: response.status_code = 200 response._content = b"""{ "model_type": "LLM", diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index cbe9c5914f1f0d..88435c40227371 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -411,5 +411,5 @@ def test_chat_parameter_extractor_with_memory(setup_anthropic_mock): if latest_role is not None: assert latest_role != prompt.get("role") - if prompt.get("role") in ["user", "assistant"]: + if prompt.get("role") in {"user", "assistant"}: latest_role = prompt.get("role") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index a2d71d61fcef52..197288adba3ea1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -210,7 +210,7 @@ def llm_generator(self): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]: + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in {"llm2", "llm3", "end1", "end2"}: assert item.parallel_id is not None assert len(items) == 18 @@ -315,12 +315,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent) - if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [ + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in { "answer2", "answer3", "answer4", "answer5", - ]: + }: assert item.parallel_id is not None assert len(items) == 23