Skip to content

Commit

Permalink
refactor: tool parameter cache (langgenius#3703)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Apr 23, 2024
1 parent 3696737 commit 3268c7e
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 58 deletions.
47 changes: 3 additions & 44 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

from flask_login import current_user
from flask_restful import Resource, inputs, marshal_with, reqparse
from werkzeug.exceptions import BadRequest, Forbidden
Expand All @@ -8,17 +6,12 @@
from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from core.agent.entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from extensions.ext_database import db
from fields.app_fields import (
app_detail_fields,
app_detail_fields_with_site,
app_pagination_fields,
)
from libs.login import login_required
from models.model import App, AppMode, AppModelConfig
from services.app_service import AppService

ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion']
Expand Down Expand Up @@ -108,43 +101,9 @@ class AppApi(Resource):
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
# get original app model config
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
model_config: AppModelConfig = app_model.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}

# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass

# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
db.session.commit()
app_service = AppService()

app_model = app_service.get_app(app_model)

return app_model

Expand Down
12 changes: 9 additions & 3 deletions api/controllers/console/app/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ def post(self, app_model):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
except Exception as e:
continue
Expand Down Expand Up @@ -94,6 +96,7 @@ def post(self, app_model):
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
agent_tool=agent_tool_entity,
)
except Exception as e:
Expand All @@ -104,16 +107,19 @@ def post(self, app_model):
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app_model.id}'
)
manager.delete_tool_parameters_cache()

# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue

if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]

for masked_key, masked_value in masked_parameter_map[key].items():
if masked_key in agent_tool_entity.tool_parameters and \
agent_tool_entity.tool_parameters[masked_key] == masked_value:
agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key)

# encrypt parameters
if agent_tool_entity.tool_parameters:
Expand Down
1 change: 1 addition & 0 deletions api/core/agent/base_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[P
"""
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_config.app_id,
agent_tool=tool,
)
tool_entity.load_variables(self.variables_pool)
Expand Down
11 changes: 6 additions & 5 deletions api/core/helper/tool_parameter_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ class ToolParameterCacheType(Enum):

class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType,
identity_id: str
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}:identity_id:{identity_id}"

def get(self) -> Optional[dict]:
"""
Expand Down
6 changes: 4 additions & 2 deletions api/core/tools/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict
return parameter_value

@classmethod
def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
"""
get the agent tool runtime
"""
Expand All @@ -245,14 +245,15 @@ def get_agent_tool_runtime(cls, tenant_id: str, agent_tool: AgentToolEntity) ->
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
identity_id=f'AGENT.{app_id}'
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity

@classmethod
def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
"""
get the workflow tool runtime
"""
Expand All @@ -277,6 +278,7 @@ def get_workflow_tool_runtime(cls, tenant_id: str, workflow_tool: ToolEntity):
tool_runtime=tool_entity,
provider_name=workflow_tool.provider_id,
provider_type=workflow_tool.provider_type,
identity_id=f'WORKFLOW.{app_id}.{node_id}'
)

if runtime_parameters:
Expand Down
11 changes: 8 additions & 3 deletions api/core/tools/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,13 @@ class ToolParameterConfigurationManager(BaseModel):
tool_runtime: Tool
provider_name: str
provider_type: str
identity_id: str

def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return {key: value for key, value in parameters.items()}
return deepcopy(parameters)

def _merge_parameters(self) -> list[ToolParameter]:
"""
Expand Down Expand Up @@ -176,6 +177,8 @@ def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
# override parameters
current_parameters = self._merge_parameters()

parameters = self._deep_copy(parameters)

for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
Expand All @@ -194,7 +197,8 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cached_parameters = cache.get()
if cached_parameters:
Expand Down Expand Up @@ -223,7 +227,8 @@ def delete_tool_parameters_cache(self):
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
)
cache.delete()

Expand Down
3 changes: 2 additions & 1 deletion api/core/workflow/nodes/tool/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data)
self.app_id
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ def handle(sender, **kwargs):
tool_runtime=tool_runtime,
provider_name=tool_entity.provider_name,
provider_type=tool_entity.provider_type,
identity_id=f'WORKFLOW.{app.id}.{node_data.get("id")}'
)
manager.delete_tool_parameters_cache()
62 changes: 62 additions & 0 deletions api/services/app_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@

import yaml
from flask import current_app
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination

from constants.model_template import default_app_templates
from core.agent.entities import AgentToolEntity
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated, app_was_created, app_was_deleted
from extensions.ext_database import db
from models.account import Account
Expand Down Expand Up @@ -240,6 +244,64 @@ def export_app(self, app: App) -> str:

return yaml.dump(export_data)

def get_app(self, app: App) -> App:
"""
Get App
"""
# get original app model config
if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
app_id=app.id,
agent_tool=agent_tool_entity,
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
identity_id=f'AGENT.{app.id}'
)

# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}

# override tool parameters
tool['tool_parameters'] = masked_parameter
except Exception as e:
pass

# override agent mode
model_config.agent_mode = json.dumps(agent_mode)

class ModifiedApp(App):
"""
Modified App class
"""
def __init__(self, app):
self.__dict__.update(app.__dict__)

@property
def app_model_config(self):
return model_config

app = ModifiedApp(app)

return app

def update_app(self, app: App, args: dict) -> App:
"""
Update app
Expand Down

0 comments on commit 3268c7e

Please sign in to comment.