diff --git a/docs/02-streaming.ipynb b/docs/02-streaming.ipynb index a5756d1..7eda07e 100644 --- a/docs/02-streaming.ipynb +++ b/docs/02-streaming.ipynb @@ -13,7 +13,7 @@ "source": [ "import os\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\"" + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" ] }, { @@ -49,9 +49,22 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/pydantic/_internal/_config.py:341: UserWarning: Valid config keys have changed in V2:\n", + "* 'allow_population_by_field_name' has been renamed to 'populate_by_name'\n", + "* 'smart_union' has been removed\n", + " warnings.warn(message, UserWarning)\n", + "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import json\n", "import openai\n", @@ -197,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -208,7 +221,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -233,24 +246,16 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Graph compiled successfully.\n" - ] - } - ], + "outputs": [], "source": [ "graph.compile()" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -266,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -282,15 +287,15 @@ { "data": { "text/plain": [ - "{'choice': 'finalanswer',\n", - " 'input': {'text': 'The user is in Bali right now.',\n", - " 'query': 'user location',\n", + "{'input': {'text': 'The user is in Bali right now.',\n", + " 'query': 'Where am I?',\n", " 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}],\n", - " 'answer': 'You are currently in Bali.',\n", - " 'sources': 'User context provided.'}}" + " 'answer': 'You are currently in Bali, enjoying the beautiful landscapes and vibrant culture of this popular tropical destination.',\n", + " 'sources': 'User indicated they are in Bali.'},\n", + " 'choice': 'finalanswer'}" ] }, - "execution_count": 22, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -314,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -323,28 +328,12 @@ "text": [ ">>> node_start\n", ">>> node_router\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/typing.py:409: RuntimeWarning: coroutine 'AsyncCompletions.create' was never awaited\n", - " ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)\n", - "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "\n", "\n", "{\"\n", "query\n", "\":\"\n", - "a\n", - " long\n", + "long\n", " story\n", 
"\"}\n", ">>> memory\n", @@ -358,94 +347,66 @@ "{\"\n", "answer\n", "\":\"\n", - "B\n", - "ali\n", - ",\n", - " renowned\n", - " for\n", - " its\n", - " breathtaking\n", - " beaches\n", - ",\n", - " vibrant\n", - " culture\n", - ",\n", - " and\n", - " lush\n", - " landscapes\n", - ",\n", - " is\n", - " truly\n", - " a\n", - " magical\n", - " place\n", - ".\n", - " Whether\n", - " you\n", - "’re\n", - " exploring\n", + "A\n", + " \\\"\n", + "long\n", + " story\n", + "\\\"\n", + " can\n", + " often\n", + " involve\n", " the\n", - " rice\n", - " terraces\n", + " intric\n", + "acies\n", + " and\n", + " details\n", " of\n", - " U\n", - "bud\n", - ",\n", - " surfing\n", - " the\n", - " waves\n", - " at\n", - " K\n", - "uta\n", - " Beach\n", - ",\n", + " our\n", + " current\n", + " experiences\n", " or\n", - " indul\n", - "ging\n", + " feelings\n", + ".\n", + " If\n", + " you're\n", " in\n", - " the\n", - " local\n", - " cuisine\n", + " Bali\n", ",\n", - " each\n", - " experience\n", - " adds\n", + " perhaps\n", + " there's\n", + " a\n", + " beautiful\n", + " or\n", + " memorable\n", + " tale\n", + " connected\n", " to\n", - " the\n", - " richness\n", - " of\n", - " your\n", - " journey\n", - ".\n", - " What\n", - " specific\n", - " aspects\n", - " of\n", " your\n", " time\n", - " in\n", - " Bali\n", - " would\n", - " you\n", - " like\n", + " there\n", + "!\n", + " \n", + " Feel\n", + " free\n", " to\n", - " hear\n", - " a\n", - " long\n", - " story\n", + " share\n", + " more\n", " about\n", - "?\n", + " your\n", + " situation\n", + " if\n", + " you'd\n", + " like\n", + " -\n", + " I'm\n", + " here\n", + " to\n", + " listen\n", + "!\n", "\",\"\n", "sources\n", - "\":\"\n", - "User\n", - " context\n", - " stating\n", - " they\n", - " are\n", - " in\n", - " Bali\n", - ".\"\n", + "\":\n", + "\"\"\n", "}\n", "\n", "\n" diff --git a/graphai/graph.py b/graphai/graph.py index 46b5256..66a4887 100644 --- a/graphai/graph.py +++ b/graphai/graph.py @@ -1,17 +1,18 @@ -from typing import List +from typing import List, Dict, Any from graphai.nodes.base import _Node from graphai.callback import Callback from semantic_router.utils.logger import logger class Graph: - def __init__(self): + def __init__(self, max_steps: int = 10): self.nodes = [] self.edges = [] self.start_node = None self.end_nodes = [] self.Callback = Callback self.callback = None + self.max_steps = max_steps def add_node(self, node): self.nodes.append(node) @@ -52,43 +53,58 @@ def compile(self): raise Exception("No end nodes defined.") if not self._is_valid(): raise Exception("Graph is not valid.") - print("Graph compiled successfully.") def _is_valid(self): # Implement validation logic, e.g., checking for cycles, disconnected components, etc. return True + def _validate_output(self, output: Dict[str, Any], node_name: str): + if not isinstance(output, dict): + raise ValueError( + f"Expected dictionary output from node {node_name}. " + f"Instead, got {type(output)} from '{output}'." + ) + async def execute(self, input): # TODO JB: may need to add init callback here to init the queue on every new execution if self.callback is None: self.callback = self.get_callback() current_node = self.start_node state = input + steps = 0 while True: # we invoke the node here if current_node.stream: - if self.callback is None: - # TODO JB: can remove? - raise ValueError("No callback provided to graph. 
Please add it via `.add_callback`.")
                 # add callback tokens and param here if we are streaming
                 await self.callback.start_node(node_name=current_node.name)
-                state = await current_node.invoke(input=state, callback=self.callback)
+                output = await current_node.invoke(input=state, callback=self.callback)
+                self._validate_output(output=output, node_name=current_node.name)
                 await self.callback.end_node(node_name=current_node.name)
             else:
-                state = await current_node.invoke(input=state)
+                output = await current_node.invoke(input=state)
+                self._validate_output(output=output, node_name=current_node.name)
             if current_node.is_router:
                 # if we have a router node we let the router decide the next node
-                next_node_name = str(state["choice"])
-                del state["choice"]
-                current_node = self._get_node_by_name(next_node_name)
+                next_node_name = str(output["choice"])
+                del output["choice"]
+                current_node = self._get_node_by_name(node_name=next_node_name)
             else:
                 # otherwise, we have linear path
-                current_node = self._get_next_node(current_node, state)
+                current_node = self._get_next_node(current_node=current_node)
+            # add output to state
+            state = {**state, **output}
             if current_node.is_end:
                 break
+            steps += 1
+            if steps >= self.max_steps:
+                raise Exception(
+                    f"Max steps reached: {self.max_steps}. You can modify this "
+                    "by setting `max_steps` when initializing the Graph object."
+                )
         # TODO JB: may need to add end callback here to close the queue for every execution
         if self.callback:
             await self.callback.close()
+        state.pop("callback", None)  # callback only enters state when a node streams
         return state
 
     def get_callback(self):
@@ -101,7 +117,7 @@ def _get_node_by_name(self, node_name: str) -> _Node:
                 return node
         raise Exception(f"Node with name {node_name} not found.")
 
-    def _get_next_node(self, current_node, output):
+    def _get_next_node(self, current_node):
         for edge in self.edges:
             if edge.source == current_node:
                 return edge.destination
diff --git a/graphai/nodes/base.py b/graphai/nodes/base.py
index 78a665a..50cf824 100644
--- a/graphai/nodes/base.py
+++ b/graphai/nodes/base.py
@@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, Optional
 
 from graphai.callback import Callback
-from semantic_router.utils.logger import logger
+from graphai.utils import FunctionSchema
 
 
 class NodeMeta(type):
@@ -38,29 +38,54 @@ def _node(
             raise ValueError("Node must be a callable function.")
 
         func_signature = inspect.signature(func)
+        schema = FunctionSchema(func)
 
         class NodeClass:
             _func_signature = func_signature
             is_router = None
             _stream = stream
 
-            def __init__(self, *args, **kwargs):
-                bound_args = self._func_signature.bind(*args, **kwargs)
-                bound_args.apply_defaults()
-                for name, value in bound_args.arguments.items():
-                    setattr(self, name, value)
+            def __init__(self):
+                self._expected_params = set(self._func_signature.parameters.keys())
+
+            async def execute(self, *args, **kwargs):
+                # Prepare arguments, including callback if stream is True
+                params_dict = await self._parse_params(*args, **kwargs)
+                return await func(**params_dict)  # Pass only the necessary arguments
+
+            async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
+                # filter out unexpected keyword args
+                expected_kwargs = {k: v for k, v in kwargs.items() if k in self._expected_params}
+                # Convert args to kwargs based on the function signature
+                args_names = list(self._func_signature.parameters.keys())[1:len(args)+1]  # skip 'self'
+                expected_args_kwargs = dict(zip(args_names, args))
+                # Combine filtered args and kwargs
+                combined_params = {**expected_args_kwargs, **expected_kwargs}
 
-            async def execute(self):
                 # Bind the 
current instance attributes to the function signature - if "callback" in self.__dict__.keys() and not stream: + if "callback" in self._expected_params and not stream: raise ValueError( f"Node {func.__name__}: requires stream=True when callback is defined." ) - bound_args = self._func_signature.bind(**self.__dict__) - bound_args.apply_defaults() - # Prepare arguments, including callback if stream is True - args_dict = bound_args.arguments.copy() # Copy arguments to modify safely - return await func(**args_dict) # Pass only the necessary arguments + bound_params = self._func_signature.bind_partial(**combined_params) + # get the default parameters (if any) + bound_params.apply_defaults() + params_dict = bound_params.arguments.copy() + # Filter arguments to match the next node's parameters + filtered_params = { + k: v for k, v in params_dict.items() if k in self._expected_params + } + # confirm all required parameters are present + missing_params = [ + p for p in self._expected_params if p not in filtered_params + ] + # if anything is missing we raise an error + if missing_params: + raise ValueError( + f"Missing required parameters for the {func.__name__} node: {', '.join(missing_params)}" + ) + return filtered_params + @classmethod def get_signature(cls): @@ -87,8 +112,8 @@ async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None raise ValueError( f"Error in node {func.__name__}. When callback provided, stream must be True." ) - instance = cls(**input) - out = await instance.execute() + instance = cls() + out = await instance.execute(**input) return out NodeClass.__name__ = func.__name__ @@ -98,7 +123,7 @@ async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None NodeClass.is_end = end NodeClass.is_router = self.is_router NodeClass.stream = stream - + NodeClass.schema = schema return NodeClass def __call__( diff --git a/graphai/utils.py b/graphai/utils.py new file mode 100644 index 0000000..51a818e --- /dev/null +++ b/graphai/utils.py @@ -0,0 +1,125 @@ +import inspect +from typing import Any, Callable, Dict, List, Union, Optional +from pydantic import BaseModel, Field + + +class Parameter(BaseModel): + class Config: + arbitrary_types_allowed = True + + name: str = Field(description="The name of the parameter") + description: Optional[str] = Field( + default=None, description="The description of the parameter" + ) + type: str = Field(description="The type of the parameter") + default: Any = Field(description="The default value of the parameter") + required: bool = Field(description="Whether the parameter is required") + + def to_openai(self): + return { + self.name: { + "description": self.description, + "type": self.type, + } + } + +class FunctionSchema: + """Class that consumes a function and can return a schema required by + different LLMs for function calling. 
+    """
+
+    name: str = Field(description="The name of the function")
+    description: str = Field(description="The description of the function")
+    signature: str = Field(description="The signature of the function")
+    output: str = Field(description="The output of the function")
+    parameters: List[Parameter] = Field(description="The parameters of the function")
+
+    def __init__(self, function: Union[Callable, BaseModel]):
+        self.function = function
+        if callable(function):
+            self._process_function(function)
+        elif isinstance(function, BaseModel):
+            raise NotImplementedError("Pydantic BaseModel not implemented yet.")
+        else:
+            raise TypeError("Function must be a Callable or BaseModel")
+
+    def _process_function(self, function: Callable):
+        self.name = function.__name__
+        self.description = str(inspect.getdoc(function))
+        self.signature = str(inspect.signature(function))
+        self.output = str(inspect.signature(function).return_annotation)
+        parameters = []
+        for param in inspect.signature(function).parameters.values():
+            parameters.append(
+                Parameter(
+                    name=param.name,
+                    type=param.annotation.__name__,
+                    default=param.default,
+                    required=param.default is inspect.Parameter.empty,
+                )
+            )
+        self.parameters = parameters
+
+    def to_openai(self):
+        schema_dict = {
+            "type": "function",
+            "function": {
+                "name": self.name,
+                "description": self.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param.name: {
+                            "description": (
+                                param.description
+                                if isinstance(param.description, str)
+                                else "None provided"
+                            ),
+                            "type": self._openai_type_mapping(param.type),
+                        }
+                        for param in self.parameters
+                    },
+                    "required": [
+                        param.name for param in self.parameters if param.required
+                    ],
+                },
+            },
+        }
+        return schema_dict
+
+    def _openai_type_mapping(self, param_type: str) -> str:
+        if param_type == "int":
+            return "number"
+        elif param_type == "float":
+            return "number"
+        elif param_type == "str":
+            return "string"
+        elif param_type == "bool":
+            return "boolean"
+        else:
+            return "object"
+
+
+def get_schema_pydantic(model: BaseModel) -> FunctionSchema:
+    signature_parts = []
+    for field_name, field_model in model.__annotations__.items():
+        field_info = model.__fields__[field_name]
+        default_value = field_info.default
+
+        if default_value is not None:
+            default_repr = repr(default_value)
+            signature_part = (
+                f"{field_name}: {field_model.__name__} = {default_repr}"
+            )
+        else:
+            signature_part = f"{field_name}: {field_model.__name__}"
+
+        signature_parts.append(signature_part)
+    signature = f"({', '.join(signature_parts)}) -> str"
+    schema = FunctionSchema.__new__(FunctionSchema)  # __init__ only accepts a callable
+    schema.name = model.__class__.__name__
+    schema.description = model.__doc__ or ""
+    schema.signature = signature
+    schema.output = ""  # TODO: Implement output
+    schema.parameters = []
+    return schema
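
To illustrate the new `graphai/utils.py` module above, here is a minimal usage sketch of `FunctionSchema.to_openai`; the `search` function is a hypothetical example, not part of graphai:

from graphai.utils import FunctionSchema

def search(query: str, top_k: int = 5) -> str:
    """Search the web for the given query."""
    return f"results for {query!r}"

# consume the plain function and emit an OpenAI-style tool schema:
# `query` appears under "required" while `top_k` is optional, because
# _process_function marks a parameter required only when it has no default
schema = FunctionSchema(search)
print(schema.to_openai())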