Skip to content

Commit

Permalink
Add decorator for function calling (microsoft#1018)
Browse files Browse the repository at this point in the history
* add function decorator to converasble agent

* polishing

* polishing

* added function decorator to the notebook with async function calls

* added support for return type hint and JSON encoding of returned value if needed

* polishing

* polishing

* refactored async case

* Python 3.8 support added

* polishing

* polishing

* missing docs added

* refacotring and changes as requested

* getLogger

* documentation added

* test fix

* test fix

* added testing of agentchat_function_call_currency_calculator.ipynb to test_notebook.py

* added support for Pydantic parameters in function decorator

* polishing

* Update website/docs/Use-Cases/agent_chat.md

Co-authored-by: Li Jiang <bnujli@gmail.com>

* Update website/docs/Use-Cases/agent_chat.md

Co-authored-by: Li Jiang <bnujli@gmail.com>

* fixes problem with logprob parameter in openai.types.chat.chat_completion.Choice added by openai version 1.5.0

* get 100% code coverage on code added

* updated docs

* default values added to JSON schema

* serialization using json.dump() add for values not string or BaseModel

* added limit to openai version because of breaking changes in 1.5.0

* added line-by-line comments in docs to explain the process

* polishing

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
  • Loading branch information
3 people authored Dec 25, 2023
1 parent b1adac5 commit 4b5ec5a
Show file tree
Hide file tree
Showing 19 changed files with 2,164 additions and 203 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ node_modules/
*.log

# Python virtualenv
.venv
.venv*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
110 changes: 110 additions & 0 deletions autogen/_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Any, Dict, Optional, Tuple, Type, Union, get_args

from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin

__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")

PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")

if not PYDANTIC_V1:
from pydantic import TypeAdapter
from pydantic._internal._typing_extra import eval_type_lenient as evaluate_forwardref
from pydantic.json_schema import JsonSchemaValue

def type2schema(t: Optional[Type]) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
t (Type): The type to convert
Returns:
JsonSchemaValue: The JSON schema
"""
return TypeAdapter(t).json_schema()

def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
Args:
model (BaseModel): The model to convert
Returns:
Dict[str, Any]: The dict representation of the model
"""
return model.model_dump()

def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string
Args:
model (BaseModel): The model to convert
Returns:
str: The JSON string representation of the model
"""
return model.model_dump_json()


# Remove this once we drop support for pydantic 1.x
else: # pragma: no cover
from pydantic import schema_of
from pydantic.typing import evaluate_forwardref as evaluate_forwardref

JsonSchemaValue = Dict[str, Any]

def type2schema(t: Optional[Type]) -> JsonSchemaValue:
"""Convert a type to a JSON schema
Args:
t (Type): The type to convert
Returns:
JsonSchemaValue: The JSON schema
"""
if PYDANTIC_V1:
if t is None:
return {"type": "null"}
elif get_origin(t) is Union:
return {"anyOf": [type2schema(tt) for tt in get_args(t)]}
elif get_origin(t) in [Tuple, tuple]:
prefixItems = [type2schema(tt) for tt in get_args(t)]
return {
"maxItems": len(prefixItems),
"minItems": len(prefixItems),
"prefixItems": prefixItems,
"type": "array",
}

d = schema_of(t)
if "title" in d:
d.pop("title")
if "description" in d:
d.pop("description")

return d

def model_dump(model: BaseModel) -> Dict[str, Any]:
"""Convert a pydantic model to a dict
Args:
model (BaseModel): The model to convert
Returns:
Dict[str, Any]: The dict representation of the model
"""
return model.dict()

def model_dump_json(model: BaseModel) -> str:
"""Convert a pydantic model to a JSON string
Args:
model (BaseModel): The model to convert
Returns:
str: The JSON string representation of the model
"""
return model.json()
4 changes: 3 additions & 1 deletion autogen/agentchat/contrib/math_user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from time import sleep

from autogen._pydantic import PYDANTIC_V1
from autogen.agentchat import Agent, UserProxyAgent
from autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from autogen.math_utils import get_answer
Expand Down Expand Up @@ -384,7 +385,8 @@ class WolframAlphaAPIWrapper(BaseModel):
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
if PYDANTIC_V1:
extra = Extra.forbid

@root_validator(skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
Expand Down
167 changes: 163 additions & 4 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import asyncio
import copy
import functools
import inspect
import json
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union

from autogen import OpenAIWrapper
from autogen.code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from .. import OpenAIWrapper
from ..code_utils import DEFAULT_MODEL, UNKNOWN, content_str, execute_code, extract_code, infer_lang
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
from .agent import Agent

try:
Expand All @@ -19,8 +20,12 @@ def colored(x, *args, **kwargs):
return x


__all__ = ("ConversableAgent",)

logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


class ConversableAgent(Agent):
"""(In preview) A class for generic conversable agents which can be configured as assistant or user proxy.
Expand Down Expand Up @@ -1330,3 +1335,157 @@ def can_execute_function(self, name: str) -> bool:
def function_map(self) -> Dict[str, Callable]:
"""Return the function map."""
return self._function_map

def _wrap_function(self, func: F) -> F:
"""Wrap the function to dump the return value to json.
Handles both sync and async functions.
Args:
func: the function to be wrapped.
Returns:
The wrapped function.
"""

@load_basemodels_if_needed
@functools.wraps(func)
def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs)

return serialize_to_str(retval)

@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs)
return serialize_to_str(retval)

wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func

# needed for testing
wrapped_func._origin = func

return wrapped_func

def register_for_llm(
self,
*,
name: Optional[str] = None,
description: Optional[str] = None,
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be used by an agent.
It's return value is used to decorate a function to be registered to the agent. The function uses type hints to
specify the arguments and return type. The function name is used as the default name for the function,
but a custom name can be provided. The function description is used to describe the function in the
agent's configuration.
Args:
name (optional(str)): name of the function. If None, the function name will be used (default: None).
description (optional(str)): description of the function (default: None). It is mandatory
for the initial decorator, but the following ones can omit it.
Returns:
The decorator for registering a function to be used by an agent.
Examples:
```
@user_proxy.register_for_execution()
@agent2.register_for_llm()
@agent1.register_for_llm(description="This is a very useful function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14) -> str:
return a + str(b * c)
```
"""

def _decorator(func: F) -> F:
"""Decorator for registering a function to be used by an agent.
Args:
func: the function to be registered.
Returns:
The function to be registered, with the _description attribute set to the function description.
Raises:
ValueError: if the function description is not provided and not propagated by a previous decorator.
RuntimeError: if the LLM config is not set up before registering a function.
"""
# name can be overwriten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

# description is propagated from the previous decorator, but it is mandatory for the first one
if description:
func._description = description
else:
if not hasattr(func, "_description"):
raise ValueError("Function description is required, none found.")

# get JSON schema for the function
f = get_function_schema(func, name=func._name, description=func._description)

# register the function to the agent if there is LLM config, raise an exception otherwise
if self.llm_config is None:
raise RuntimeError("LLM config must be setup before registering a function for LLM.")

self.update_function_signature(f, is_remove=False)

return func

return _decorator

def register_for_execution(
self,
name: Optional[str] = None,
) -> Callable[[F], F]:
"""Decorator factory for registering a function to be executed by an agent.
It's return value is used to decorate a function to be registered to the agent.
Args:
name (optional(str)): name of the function. If None, the function name will be used (default: None).
Returns:
The decorator for registering a function to be used by an agent.
Examples:
```
@user_proxy.register_for_execution()
@agent2.register_for_llm()
@agent1.register_for_llm(description="This is a very useful function")
def my_function(a: Annotated[str, "description of a parameter"] = "a", b: int, c=3.14):
return a + str(b * c)
```
"""

def _decorator(func: F) -> F:
"""Decorator for registering a function to be used by an agent.
Args:
func: the function to be registered.
Returns:
The function to be registered, with the _description attribute set to the function description.
Raises:
ValueError: if the function description is not provided and not propagated by a previous decorator.
"""
# name can be overwriten by the parameter, by default it is the same as function name
if name:
func._name = name
elif not hasattr(func, "_name"):
func._name = func.__name__

self.register_function({func._name: self._wrap_function(func)})

return func

return _decorator
Loading

0 comments on commit 4b5ec5a

Please sign in to comment.