Skip to content

Commit

Permalink
Merge pull request #5 from aurelio-labs/james/better-tools
Browse files Browse the repository at this point in the history
feat: get func schemas, set max steps
  • Loading branch information
jamescalam authored Sep 12, 2024
2 parents b6e30fd + ce5dd1e commit bbfb89d
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 147 deletions.
197 changes: 79 additions & 118 deletions docs/02-streaming.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"source": [
"import os\n",
"\n",
"os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\""
"os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
]
},
{
Expand Down Expand Up @@ -49,9 +49,22 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/pydantic/_internal/_config.py:341: UserWarning: Valid config keys have changed in V2:\n",
"* 'allow_population_by_field_name' has been renamed to 'populate_by_name'\n",
"* 'smart_union' has been removed\n",
" warnings.warn(message, UserWarning)\n",
"/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import json\n",
"import openai\n",
Expand Down Expand Up @@ -197,7 +210,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -208,7 +221,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -233,24 +246,16 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Graph compiled successfully.\n"
]
}
],
"outputs": [],
"source": [
"graph.compile()"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -266,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -282,15 +287,15 @@
{
"data": {
"text/plain": [
"{'choice': 'finalanswer',\n",
" 'input': {'text': 'The user is in Bali right now.',\n",
" 'query': 'user location',\n",
"{'input': {'text': 'The user is in Bali right now.',\n",
" 'query': 'Where am I?',\n",
" 'chat_history': [{'role': 'user', 'content': 'do you remember where I am?'}],\n",
" 'answer': 'You are currently in Bali.',\n",
" 'sources': 'User context provided.'}}"
" 'answer': 'You are currently in Bali, enjoying the beautiful landscapes and vibrant culture of this popular tropical destination.',\n",
" 'sources': 'User indicated they are in Bali.'},\n",
" 'choice': 'finalanswer'}"
]
},
"execution_count": 22,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -314,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand All @@ -323,28 +328,12 @@
"text": [
">>> node_start\n",
">>> node_router\n",
"<graphai:start:node_router>\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jamesbriggs/opt/anaconda3/envs/graphai/lib/python3.11/typing.py:409: RuntimeWarning: coroutine 'AsyncCompletions.create' was never awaited\n",
" ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)\n",
"RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<graphai:start:node_router>\n",
"<graphai:toolname:memory>\n",
"{\"\n",
"query\n",
"\":\"\n",
"a\n",
" long\n",
"long\n",
" story\n",
"\"}\n",
">>> memory\n",
Expand All @@ -358,94 +347,66 @@
"{\"\n",
"answer\n",
"\":\"\n",
"B\n",
"ali\n",
",\n",
" renowned\n",
" for\n",
" its\n",
" breathtaking\n",
" beaches\n",
",\n",
" vibrant\n",
" culture\n",
",\n",
" and\n",
" lush\n",
" landscapes\n",
",\n",
" is\n",
" truly\n",
" a\n",
" magical\n",
" place\n",
".\n",
" Whether\n",
" you\n",
"’re\n",
" exploring\n",
"A\n",
" \\\"\n",
"long\n",
" story\n",
"\\\"\n",
" can\n",
" often\n",
" involve\n",
" the\n",
" rice\n",
" terraces\n",
" intric\n",
"acies\n",
" and\n",
" details\n",
" of\n",
" U\n",
"bud\n",
",\n",
" surfing\n",
" the\n",
" waves\n",
" at\n",
" K\n",
"uta\n",
" Beach\n",
",\n",
" our\n",
" current\n",
" experiences\n",
" or\n",
" indul\n",
"ging\n",
" feelings\n",
".\n",
" If\n",
" you're\n",
" in\n",
" the\n",
" local\n",
" cuisine\n",
" Bali\n",
",\n",
" each\n",
" experience\n",
" adds\n",
" perhaps\n",
" there's\n",
" a\n",
" beautiful\n",
" or\n",
" memorable\n",
" tale\n",
" connected\n",
" to\n",
" the\n",
" richness\n",
" of\n",
" your\n",
" journey\n",
".\n",
" What\n",
" specific\n",
" aspects\n",
" of\n",
" your\n",
" time\n",
" in\n",
" Bali\n",
" would\n",
" you\n",
" like\n",
" there\n",
"!\n",
" \n",
" Feel\n",
" free\n",
" to\n",
" hear\n",
" a\n",
" long\n",
" story\n",
" share\n",
" more\n",
" about\n",
"?\n",
" your\n",
" situation\n",
" if\n",
" you'd\n",
" like\n",
" -\n",
" I'm\n",
" here\n",
" to\n",
" listen\n",
"!\n",
"\",\"\n",
"sources\n",
"\":\"\n",
"User\n",
" context\n",
" stating\n",
" they\n",
" are\n",
" in\n",
" Bali\n",
".\"\n",
"\":\n",
"\"\"\n",
"}\n",
"<graphai:end:llm_node>\n",
"<graphai:END>\n"
Expand Down
42 changes: 29 additions & 13 deletions graphai/graph.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import List
from typing import List, Dict, Any
from graphai.nodes.base import _Node
from graphai.callback import Callback
from semantic_router.utils.logger import logger


class Graph:
def __init__(self):
def __init__(self, max_steps: int = 10):
self.nodes = []
self.edges = []
self.start_node = None
self.end_nodes = []
self.Callback = Callback
self.callback = None
self.max_steps = max_steps

def add_node(self, node):
self.nodes.append(node)
Expand Down Expand Up @@ -52,43 +53,58 @@ def compile(self):
raise Exception("No end nodes defined.")
if not self._is_valid():
raise Exception("Graph is not valid.")
print("Graph compiled successfully.")

def _is_valid(self):
# Implement validation logic, e.g., checking for cycles, disconnected components, etc.
return True

def _validate_output(self, output: Dict[str, Any], node_name: str):
if not isinstance(output, dict):
raise ValueError(
f"Expected dictionary output from node {node_name}. "
f"Instead, got {type(output)} from '{output}'."
)

async def execute(self, input):
# TODO JB: may need to add init callback here to init the queue on every new execution
if self.callback is None:
self.callback = self.get_callback()
current_node = self.start_node
state = input
steps = 0
while True:
# we invoke the node here
if current_node.stream:
if self.callback is None:
# TODO JB: can remove?
raise ValueError("No callback provided to graph. Please add it via `.add_callback`.")
# add callback tokens and param here if we are streaming
await self.callback.start_node(node_name=current_node.name)
state = await current_node.invoke(input=state, callback=self.callback)
output = await current_node.invoke(input=state, callback=self.callback)
self._validate_output(output=output, node_name=current_node.name)
await self.callback.end_node(node_name=current_node.name)
else:
state = await current_node.invoke(input=state)
output = await current_node.invoke(input=state)
self._validate_output(output=output, node_name=current_node.name)
if current_node.is_router:
# if we have a router node we let the router decide the next node
next_node_name = str(state["choice"])
del state["choice"]
current_node = self._get_node_by_name(next_node_name)
next_node_name = str(output["choice"])
del output["choice"]
current_node = self._get_node_by_name(node_name=next_node_name)
else:
# otherwise, we have linear path
current_node = self._get_next_node(current_node, state)
current_node = self._get_next_node(current_node=current_node)
# add output to state
state = {**state, **output}
if current_node.is_end:
break
steps += 1
if steps >= self.max_steps:
raise Exception(
f"Max steps reached: {self.max_steps}. You can modify this "
"by setting `max_steps` when initializing the Graph object."
)
# TODO JB: may need to add end callback here to close the queue for every execution
if self.callback:
await self.callback.close()
del state["callback"]
return state

def get_callback(self):
Expand All @@ -101,7 +117,7 @@ def _get_node_by_name(self, node_name: str) -> _Node:
return node
raise Exception(f"Node with name {node_name} not found.")

def _get_next_node(self, current_node, output):
def _get_next_node(self, current_node):
for edge in self.edges:
if edge.source == current_node:
return edge.destination
Expand Down
Loading

0 comments on commit bbfb89d

Please sign in to comment.