From c27f7ac1021b540565c5cf9b502381a757ad4412 Mon Sep 17 00:00:00 2001 From: Manthan Gupta Date: Mon, 23 Dec 2024 19:57:48 +0530 Subject: [PATCH] update --- cookbook/workflows/startup_idea_validator.py | 2 +- phi/agent/agent.py | 28 ++++++++------- phi/tools/function.py | 5 +-- phi/workflow/workflow.py | 38 ++++++++++++-------- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/cookbook/workflows/startup_idea_validator.py b/cookbook/workflows/startup_idea_validator.py index c4070f622..99bc89dfc 100644 --- a/cookbook/workflows/startup_idea_validator.py +++ b/cookbook/workflows/startup_idea_validator.py @@ -205,7 +205,7 @@ def run(self, startup_idea: str) -> Iterator[RunResponse]: table_name="validate_startup_ideas_workflow", db_file="tmp/workflows.db", ), - debug_mode=True + debug_mode=True, ) final_report: Iterator[RunResponse] = startup_idea_validator.run(startup_idea=idea) diff --git a/phi/agent/agent.py b/phi/agent/agent.py index fbcb139da..18020c49d 100644 --- a/phi/agent/agent.py +++ b/phi/agent/agent.py @@ -353,7 +353,7 @@ def _deep_copy_field(self, field_name: str, field_value: Any) -> Any: def has_team(self) -> bool: return self.team is not None and len(self.team) > 0 - + def has_workflows(self) -> bool: return self.workflows is not None and len(self.workflows) > 0 @@ -463,7 +463,7 @@ def get_transfer_prompt(self) -> str: transfer_prompt += f"Available tools: {', '.join(_tools)}\n" return transfer_prompt return "" - + def get_workflow_prompt(self) -> str: if self.has_workflows(): workflow_prompt = "## Available Workflows:" @@ -484,16 +484,20 @@ def get_workflow_function(self, workflow: Workflow, index: int) -> List[Function workflow_functions = [] for func_name, func in workflow._registered_functions.items(): - workflow_functions.append(Function( - name=f"{func_name}", - description=func.get("description") if func.get("description") else f"Use this function to run the {func_name} function of the {workflow_name} workflow", - parameters={"type": "object", "properties": func["parameters"], "required": func["required"]}, - entrypoint=func["function"], - class_instance=workflow, - sanitize_arguments=True, - show_result=False, - stop_after_tool_call=False - )) + workflow_functions.append( + Function( + name=f"{func_name}", + description=func.get("description") + if func.get("description") + else f"Use this function to run the {func_name} function of the {workflow_name} workflow", + parameters={"type": "object", "properties": func["parameters"], "required": func["required"]}, + entrypoint=func["function"], + class_instance=workflow, + sanitize_arguments=True, + show_result=False, + stop_after_tool_call=False, + ) + ) return workflow_functions diff --git a/phi/tools/function.py b/phi/tools/function.py index 050d5a7ee..7a816c38d 100644 --- a/phi/tools/function.py +++ b/phi/tools/function.py @@ -319,9 +319,10 @@ def execute(self) -> bool: if "fc" in signature(self.function.entrypoint).parameters: entrypoint_args["fc"] = self - if self.function.class_instance is not None: - self.result = self.function.entrypoint.__get__(self.function.class_instance)(**entrypoint_args, **self.arguments) + self.result = self.function.entrypoint.__get__(self.function.class_instance)( + **entrypoint_args, **self.arguments + ) else: self.result = self.function.entrypoint(**entrypoint_args, **self.arguments) diff --git a/phi/workflow/workflow.py b/phi/workflow/workflow.py index 6db4197c8..283166d01 100644 --- a/phi/workflow/workflow.py +++ b/phi/workflow/workflow.py @@ -456,15 +456,16 @@ def _deep_copy_field(self, field_name: str, field_value: Any) -> Any: # For other types, return as is return field_value - + @classmethod def register(cls, description: Optional[str] = None) -> Callable: """ Decorator to register functions within the workflow. - + Args: description: Optional description of what the function does """ + def decorator(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): @@ -472,34 +473,41 @@ def wrapper(*args, **kwargs): # Get function signature parameters type_mapping = { - str: 'string', - int: 'integer', - bool: 'boolean', - float: 'number', - list: 'array', - dict: 'object' + str: "string", + int: "integer", + bool: "boolean", + float: "number", + list: "array", + dict: "object", } sig = inspect.signature(func) params = { name: { - 'type': type_mapping.get(param.annotation, param.annotation.__name__) if param.annotation != inspect.Parameter.empty else None, + "type": type_mapping.get(param.annotation, param.annotation.__name__) + if param.annotation != inspect.Parameter.empty + else None, } for name, param in sig.parameters.items() - if name != 'self' + if name != "self" } # Ensure _registered_functions is initialized as a dictionary - if not hasattr(cls, '_registered_functions') or not isinstance(cls._registered_functions, dict): + if not hasattr(cls, "_registered_functions") or not isinstance(cls._registered_functions, dict): cls._registered_functions = {} # Update function metadata cls._registered_functions[func.__name__] = { - 'function': wrapper, - 'description': description or func.__doc__ or "No description provided", - 'parameters': params, - 'required': [name for name, param in sig.parameters.items() if param.default is inspect.Parameter.empty and name != 'self'], + "function": wrapper, + "description": description or func.__doc__ or "No description provided", + "parameters": params, + "required": [ + name + for name, param in sig.parameters.items() + if param.default is inspect.Parameter.empty and name != "self" + ], } return wrapper + return decorator