Skip to content

Commit

Permalink
feat: working initial graph
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Jul 25, 2024
1 parent 27dd7ca commit b0ad380
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 40 deletions.
58 changes: 44 additions & 14 deletions docs/00-getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,18 @@
"source": [
"from graphai import node\n",
"\n",
"@node\n",
"def my_node(param_a: str, param_b: str = \"hi\"):\n",
"@node(start=True)\n",
"def node_a(param_a: str):\n",
" \"\"\"Descriptive string for the node.\n",
" \"\"\"\n",
" return \"Hello, World!\""
" return {\"param_b\": \"Hello, World!\"}\n",
"\n",
"@node(end=True)\n",
"def node_b(param_b: str):\n",
" \"\"\"Descriptive string for the node.\n",
" \"\"\"\n",
" new_str = param_b*2\n",
" return {\"output\": new_str}"
]
},
{
Expand All @@ -23,7 +30,7 @@
{
"data": {
"text/plain": [
"graphai.nodes.base._Node._node.<locals>.NodeClass"
"\"param_a: <class 'str'>\""
]
},
"execution_count": 2,
Expand All @@ -32,7 +39,7 @@
}
],
"source": [
"my_node"
"node_a.get_signature()"
]
},
{
Expand All @@ -43,7 +50,7 @@
{
"data": {
"text/plain": [
"<graphai.nodes.base._Node._node.<locals>.NodeClass at 0x1047fedd0>"
"{'param_b': 'Hello, World!'}"
]
},
"execution_count": 3,
Expand All @@ -52,47 +59,70 @@
}
],
"source": [
"my_node(param_a=\"l\")"
"node_a.invoke(input={\"param_a\": \"l\"})"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from graphai import Graph\n",
"\n",
"graph = Graph()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"graph.add_node(node_a)\n",
"graph.add_node(node_b)\n",
"graph.add_edge(source=node_a, destination=node_b)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Signature (param_a: str, param_b: str = 'hi')>"
"(graphai.nodes.base._Node._node.<locals>.NodeClass,\n",
" [graphai.nodes.base._Node._node.<locals>.NodeClass])"
]
},
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_node._func_signature"
"graph.start_node, graph.end_nodes"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Hello, World!'"
"{'param_b': 'Hello, World!'}"
]
},
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"my_node.invoke(input={\"param_a\": \"l\"})"
"graph.execute(input={\"param_a\": \"l\"})"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion graphai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from graphai.graph import Graph
from graphai.nodes import node

__all__ = ["node"]
__all__ = ["node", "Graph"]
61 changes: 61 additions & 0 deletions graphai/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
class Graph:
def __init__(self):
self.nodes = []
self.edges = []
self.start_node = None
self.end_nodes = []

def add_node(self, node):
self.nodes.append(node)
if node.is_start:
if self.start_node is not None:
raise Exception("Multiple start nodes are not allowed.")
self.start_node = node
if node.is_end:
self.end_nodes.append(node)

def add_edge(self, source, destination):
# TODO add logic to check that source and destination are nodes
# and they exist in the graph object already
edge = Edge(source, destination)
self.edges.append(edge)

def set_start_node(self, node):
self.start_node = node

def set_end_node(self, node):
self.end_node = node

def compile(self):
if not self.start_node:
raise Exception("Start node not defined.")
if not self.end_nodes:
raise Exception("No end nodes defined.")
if not self._is_valid():
raise Exception("Graph is not valid.")
print("Graph compiled successfully.")

def _is_valid(self):
# Implement validation logic, e.g., checking for cycles, disconnected components, etc.
return True

def execute(self, input):
current_node = self.start_node
while current_node not in self.end_nodes:
output = current_node.invoke(input=input)
current_node = self._get_next_node(current_node, output)
if current_node.is_end:
break
return output

def _get_next_node(self, current_node, output):
for edge in self.edges:
if edge.source == current_node:
return edge.destination
raise Exception("No outgoing edge found for current node.")


class Edge:
def __init__(self, source, destination):
self.source = source
self.destination = destination
48 changes: 23 additions & 25 deletions graphai/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def __call__(cls, *args, **kwargs):

class _Node:
def __init__(self):
self._func_signature = None
pass

def _node(self, func: Callable) -> Callable:
def _node(self, func: Callable, start: bool = False, end: bool = False) -> Callable:
"""Decorator validating node structure.
"""
if not callable(func):
Expand All @@ -42,46 +42,44 @@ def execute(self):
bound_args.apply_defaults()
return func(*bound_args.args, **bound_args.kwargs)

@classmethod
def get_signature(cls):
"""Returns the signature of the decorated function as LLM readable
string.
"""
signature_components = []
if NodeClass._func_signature:
for param in NodeClass._func_signature.parameters.values():
if param.default is param.empty:
signature_components.append(f"{param.name}: {param.annotation}")
else:
signature_components.append(f"{param.name}: {param.annotation} = {param.default}")
else:
return "No signature"
return "\n".join(signature_components)

@classmethod
def invoke(cls, input: Dict[str, Any]):
instance = cls(**input)
return instance.execute()

NodeClass.__name__ = func.__name__
NodeClass.__doc__ = func.__doc__
NodeClass.is_start = start
NodeClass.is_end = end

return NodeClass

def __call__(self, func: Optional[Callable] = None):
def __call__(self, func: Optional[Callable] = None, start: bool = False, end: bool = False):
# We must wrap the call to the decorator in a function for it to work
# correctly with or without parenthesis
def wrap(func: Callable) -> Callable:
return self._node(func)
return self._node(func=func, start=start, end=end)
if func:
# Decorator is called without parenthesis
return wrap(func)
return wrap(func=func, start=start, end=end)
# Decorator is called with parenthesis
return wrap

def _get_signature(self) -> inspect.Signature:
"""Returns the signature of the decorated function.
"""
return self._func_signature

def signature(self):
"""Returns the signature of the decorated function as LLM readable
string.
"""
signature_str = ""
if self._func_signature:
for param in self._func_signature.parameters.values():
if param.default is param.empty:
signature_str += f"{param.name}: {param.annotation}"
else:
signature_str += f"{param.name}: {param.annotation} = {param.default}"
else:
return "No signature"
return signature_str


node = _Node()

0 comments on commit b0ad380

Please sign in to comment.