Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend: Check tool parameters generated by model #849

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
TLK-2091 check tool parameters generated by model
  • Loading branch information
EugeneLightsOn committed Nov 19, 2024
commit 0ac0ac4ad0d70aa610e62dff036ba4dbe06a3e08
75 changes: 60 additions & 15 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,46 @@
logger = LoggerFactory().get_logger()


class BaseTool():
def check_tool_parameters(tool_definition: ToolDefinition) -> None:
def decorator(func):
def wrapper(self, *args, **kwargs):
parameter_definitions = tool_definition(self).parameter_definitions
passed_method_params = kwargs.get("parameters", {})
# Validate parameters
for param, rules in parameter_definitions.items():
is_required = rules.get("required", False)
if param not in passed_method_params:
if is_required:
raise ValueError(f"Model didn't pass required parameter: {param}")
else:
value = passed_method_params[param]
expected_type = eval(rules["type"])
if not value and is_required:
raise ValueError(f"Model passed empty value for required parameter: {param}")
if not isinstance(value, expected_type):
raise TypeError(
f"Model passed invalid parameter. Parameter '{param}' must be of type {rules['type']}, but got {type(value).__name__}"
)

return func(self, *args, **kwargs)

return wrapper

return decorator


class ParametersCheckingMeta(type):
EugeneLightsOn marked this conversation as resolved.
Show resolved Hide resolved
def __new__(cls, name, bases, dct):
EugeneLightsOn marked this conversation as resolved.
Show resolved Hide resolved
for attr_name, attr_value in dct.items():
if callable(attr_value) and attr_name == "call":
# Decorate methods with the parameter checker
dct[attr_name] = check_tool_parameters(
lambda self: self.__class__.get_tool_definition()
)(attr_value)
return super().__new__(cls, name, bases, dct)


class BaseTool(metaclass=ParametersCheckingMeta):
"""
Abstract base class for all Tools.

Expand All @@ -32,11 +71,13 @@ def _post_init_check(self):

@classmethod
@abstractmethod
def is_available(cls) -> bool: ...
def is_available(cls) -> bool:
...

@classmethod
@abstractmethod
def get_tool_definition(cls) -> ToolDefinition: ...
def get_tool_definition(cls) -> ToolDefinition:
...

@classmethod
def generate_error_message(cls) -> str | None:
Expand All @@ -51,8 +92,9 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None:

@abstractmethod
async def call(
self, parameters: dict, ctx: Any, **kwargs: Any
) -> List[Dict[str, Any]]: ...
self, parameters: dict, ctx: Any, **kwargs: Any
) -> List[Dict[str, Any]]:
...


class BaseToolAuthentication(ABC):
Expand All @@ -69,19 +111,20 @@ def __init__(self, *args, **kwargs):

def _post_init_check(self):
if any(
[
self.BACKEND_HOST is None,
self.FRONTEND_HOST is None,
self.AUTH_SECRET_KEY is None,
]
[
self.BACKEND_HOST is None,
self.FRONTEND_HOST is None,
self.AUTH_SECRET_KEY is None,
]
):
raise ValueError(
"Tool Authentication requires auth.backend_hostname, auth.frontend_hostname in configuration.yaml, "
"and auth.secret_key in the secrets.yaml configuration files."
)

@abstractmethod
def get_auth_url(self, user_id: str) -> str: ...
def get_auth_url(self, user_id: str) -> str:
...

def is_auth_required(self, session: DBSessionDep, user_id: str) -> bool:
auth = tool_auth_crud.get_tool_auth(session, self.TOOL_ID, user_id)
Expand Down Expand Up @@ -114,13 +157,15 @@ def is_auth_required(self, session: DBSessionDep, user_id: str) -> bool:

@abstractmethod
def try_refresh_token(
self, session: DBSessionDep, user_id: str, tool_auth: ToolAuth
) -> bool: ...
self, session: DBSessionDep, user_id: str, tool_auth: ToolAuth
) -> bool:
...

@abstractmethod
def retrieve_auth_token(
self, request: Request, session: DBSessionDep, user_id: str
) -> str: ...
self, request: Request, session: DBSessionDep, user_id: str
) -> str:
...

def get_token(self, session: DBSessionDep, user_id: str) -> str:
tool_auth = tool_auth_crud.get_tool_auth(session, self.TOOL_ID, user_id)
Expand Down
Loading