From c76c6ecd31d95d056ecf5eaaa12641fac4d7cf2a Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 23 Apr 2024 15:22:42 +0800 Subject: [PATCH] refactor: tool parameter cache (#3703) --- api/controllers/console/app/app.py | 47 +------------- api/controllers/console/app/model_config.py | 12 +++- api/core/agent/base_agent_runner.py | 1 + api/core/helper/tool_parameter_cache.py | 11 ++-- api/core/tools/tool_manager.py | 6 +- api/core/tools/utils/configuration.py | 11 +++- api/core/workflow/nodes/tool/tool_node.py | 3 +- ...rameters_cache_when_sync_draft_workflow.py | 1 + api/services/app_service.py | 62 +++++++++++++++++++ 9 files changed, 96 insertions(+), 58 deletions(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index c694cc7fc3f232..fb9c2c23ca60fc 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,5 +1,3 @@ -import json - from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse from werkzeug.exceptions import BadRequest, Forbidden @@ -8,17 +6,12 @@ from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.agent.entities import AgentToolEntity -from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ToolParameterConfigurationManager -from extensions.ext_database import db from fields.app_fields import ( app_detail_fields, app_detail_fields_with_site, app_pagination_fields, ) from libs.login import login_required -from models.model import App, AppMode, AppModelConfig from services.app_service import AppService ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] @@ -108,43 +101,9 @@ class AppApi(Resource): @marshal_with(app_detail_fields_with_site) def get(self, app_model): """Get app detail""" - # get original app model config - if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - model_config: AppModelConfig = app_model.app_model_config - agent_mode = model_config.agent_mode_dict - # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: - if not isinstance(tool, dict) or len(tool.keys()) <= 3: - continue - agent_tool_entity = AgentToolEntity(**tool) - # get tool - try: - tool_runtime = ToolManager.get_agent_tool_runtime( - tenant_id=current_user.current_tenant_id, - agent_tool=agent_tool_entity, - ) - manager = ToolParameterConfigurationManager( - tenant_id=current_user.current_tenant_id, - tool_runtime=tool_runtime, - provider_name=agent_tool_entity.provider_id, - provider_type=agent_tool_entity.provider_type, - ) - - # get decrypted parameters - if agent_tool_entity.tool_parameters: - parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) - masked_parameter = manager.mask_tool_parameters(parameters or {}) - else: - masked_parameter = {} - - # override tool parameters - tool['tool_parameters'] = masked_parameter - except Exception as e: - pass - - # override agent mode - model_config.agent_mode = json.dumps(agent_mode) - db.session.commit() + app_service = AppService() + + app_model = app_service.get_app(app_model) return app_model diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index a7eaee346015c1..05d7958f7df819 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -57,6 +57,7 @@ def post(self, app_model): try: tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, + app_id=app_model.id, agent_tool=agent_tool_entity, ) manager = ToolParameterConfigurationManager( @@ -64,6 +65,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, + identity_id=f'AGENT.{app_model.id}' ) except Exception as e: continue @@ -94,6 +96,7 @@ def post(self, app_model): try: tool_runtime = ToolManager.get_agent_tool_runtime( tenant_id=current_user.current_tenant_id, + app_id=app_model.id, agent_tool=agent_tool_entity, ) except Exception as e: @@ -104,6 +107,7 @@ def post(self, app_model): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, + identity_id=f'AGENT.{app_model.id}' ) manager.delete_tool_parameters_cache() @@ -111,9 +115,11 @@ def post(self, app_model): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - - if agent_tool_entity.tool_parameters == masked_parameter_map[key]: - agent_tool_entity.tool_parameters = parameter_map[key] + + for masked_key, masked_value in masked_parameter_map[key].items(): + if masked_key in agent_tool_entity.tool_parameters and \ + agent_tool_entity.tool_parameters[masked_key] == masked_value: + agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index e5b4b9a4cdf72f..485633cab1b3d5 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -163,6 +163,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P """ tool_entity = ToolManager.get_agent_tool_runtime( tenant_id=self.tenant_id, + app_id=self.app_config.app_id, agent_tool=tool, ) tool_entity.load_variables(self.variables_pool) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index db05eb18750636..a6f486e81de006 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum): class ToolParameterCache: def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, - cache_type: ToolParameterCacheType + tenant_id: str, + provider: str, + tool_name: str, + cache_type: ToolParameterCacheType, + identity_id: str ): - self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}" def get(self) -> Optional[dict]: """ diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 61e29672f92496..a29bdfcd11f81e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -222,7 +222,7 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict return parameter_value @classmethod - def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool: + def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool: """ get the agent tool runtime """ @@ -245,6 +245,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> tool_runtime=tool_entity, provider_name=agent_tool.provider_id, provider_type=agent_tool.provider_type, + identity_id=f'AGENT.{app_id}' ) runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) @@ -252,7 +253,7 @@ def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity): + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity): """ get the workflow tool runtime """ @@ -277,6 +278,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity): tool_runtime=tool_entity, provider_name=workflow_tool.provider_id, provider_type=workflow_tool.provider_type, + identity_id=f'WORKFLOW.{app_id}.{node_id}' ) if runtime_parameters: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 619e7ffd619dff..917f8411c48c36 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel): tool_runtime: Tool provider_name: str provider_type: str + identity_id: str def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: """ deep copy parameters """ - return {key: value for key, value in parameters.items()} + return deepcopy(parameters) def _merge_parameters(self) -> list[ToolParameter]: """ @@ -176,6 +177,8 @@ def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: # override parameters current_parameters = self._merge_parameters() + parameters = self._deep_copy(parameters) + for parameter in current_parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: if parameter.name in parameters: @@ -194,7 +197,8 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, - cache_type=ToolParameterCacheType.PARAMETER + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id ) cached_parameters = cache.get() if cached_parameters: @@ -223,7 +227,8 @@ def delete_tool_parameters_cache(self): tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, - cache_type=ToolParameterCacheType.PARAMETER + cache_type=ToolParameterCacheType.PARAMETER, + identity_id=self.identity_id ) cache.delete() diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 8a67284971cc20..d183dbe17b2699 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -39,7 +39,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult: parameters = self._generate_parameters(variable_pool, node_data) # get tool runtime try: - tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data) + self.app_id + tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data) except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 1f631be1ccb111..2a127d903e4320 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -22,5 +22,6 @@ def handle(sender, **kwargs): tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, provider_type=tool_entity.provider_type, + identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}' ) manager.delete_tool_parameters_cache() diff --git a/api/services/app_service.py b/api/services/app_service.py index c2f7cbb02c424c..f57da12cf8180f 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -5,13 +5,17 @@ import yaml from flask import current_app +from flask_login import current_user from flask_sqlalchemy.pagination import Pagination from constants.model_template import default_app_templates +from core.agent.entities import AgentToolEntity from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted from extensions.ext_database import db from models.account import Account @@ -240,6 +244,64 @@ def export_app(self, app: App) -> str: return yaml.dump(export_data) + def get_app(self, app: App) -> App: + """ + Get App + """ + # get original app model config + if app.mode == AppMode.AGENT_CHAT.value or app.is_agent: + model_config: AppModelConfig = app.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + if not isinstance(tool, dict) or len(tool.keys()) <= 3: + continue + agent_tool_entity = AgentToolEntity(**tool) + # get tool + try: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + app_id=app.id, + agent_tool=agent_tool_entity, + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + identity_id=f'AGENT.{app.id}' + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} + + # override tool parameters + tool['tool_parameters'] = masked_parameter + except Exception as e: + pass + + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) + + class ModifiedApp(App): + """ + Modified App class + """ + def __init__(self, app): + self.__dict__.update(app.__dict__) + + @property + def app_model_config(self): + return model_config + + app = ModifiedApp(app) + + return app + def update_app(self, app: App, args: dict) -> App: """ Update app