Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backend: Check tool parameters generated by model #849

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
TLK-2091 check tool parameters generated by model - add a recursive t…
…ype checker.
  • Loading branch information
EugeneLightsOn committed Nov 20, 2024
commit 289dbffabf41741040ef668b317e2b08e4dd626f
47 changes: 44 additions & 3 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, get_args, get_origin

from fastapi import Request

Expand All @@ -13,6 +13,48 @@

logger = LoggerFactory().get_logger()

def check_type(param_value, type_description: str) -> bool:
EugeneLightsOn marked this conversation as resolved.
Show resolved Hide resolved
try:
# Convert the type string into a type object
expected_type = eval(type_description)
EugeneLightsOn marked this conversation as resolved.
Show resolved Hide resolved
return _check_type_recursive(param_value, expected_type)
except Exception as e:
print(f"Error during type checking: {e}")
return False

def _check_type_recursive(value, expected_type) -> bool:
origin = get_origin(expected_type)

if origin is None: # Base types (int, str, ...)
return isinstance(value, expected_type)

if origin is list: # Check if the value is a list
if not isinstance(value, list):
return False
element_type = get_args(expected_type)[0]
return all(_check_type_recursive(item, element_type) for item in value)

if origin is tuple: # Tuples
# trying to help to model with tuple type by converting lists to tuples, Cohere model passed tuples as list
converted_value = tuple(value) if isinstance(value, list) else value
if not isinstance(converted_value, tuple) or len(converted_value) != len(get_args(expected_type)):
return False
return all(
_check_type_recursive(item, arg_type)
for item, arg_type in zip(value, get_args(expected_type))
)

if origin is dict: # Dictionaries
if not isinstance(value, dict):
return False
key_type, value_type = get_args(expected_type)
return all(
_check_type_recursive(k, key_type) and _check_type_recursive(v, value_type)
for k, v in value.items()
)

# NOTE: Maybe we need to handle more types in the future, depends on the use in tools and models
return False

def check_tool_parameters(tool_definition: ToolDefinition) -> None:
def decorator(func):
Expand All @@ -27,10 +69,9 @@ def wrapper(self, *args, **kwargs):
raise ValueError(f"Model didn't pass required parameter: {param}")
else:
value = passed_method_params[param]
expected_type = eval(rules["type"])
if not value and is_required:
raise ValueError(f"Model passed empty value for required parameter: {param}")
if not isinstance(value, expected_type):
if not check_type(value, rules["type"]):
raise TypeError(
f"Model passed invalid parameter. Parameter '{param}' must be of type {rules['type']}, but got {type(value).__name__}"
)
Expand Down
Loading