From b0ad380f7eb333c6d9dff6a8687d50c7c6ca31ab Mon Sep 17 00:00:00 2001 From: James Briggs Date: Thu, 25 Jul 2024 21:49:58 +0800 Subject: [PATCH] feat: working initial graph --- docs/00-getting-started.ipynb | 58 +++++++++++++++++++++++++-------- graphai/__init__.py | 3 +- graphai/graph.py | 61 +++++++++++++++++++++++++++++++++++ graphai/nodes/base.py | 48 +++++++++++++-------------- 4 files changed, 130 insertions(+), 40 deletions(-) create mode 100644 graphai/graph.py diff --git a/docs/00-getting-started.ipynb b/docs/00-getting-started.ipynb index 4a36cde..8132786 100644 --- a/docs/00-getting-started.ipynb +++ b/docs/00-getting-started.ipynb @@ -8,11 +8,18 @@ "source": [ "from graphai import node\n", "\n", - "@node\n", - "def my_node(param_a: str, param_b: str = \"hi\"):\n", + "@node(start=True)\n", + "def node_a(param_a: str):\n", " \"\"\"Descriptive string for the node.\n", " \"\"\"\n", - " return \"Hello, World!\"" + " return {\"param_b\": \"Hello, World!\"}\n", + "\n", + "@node(end=True)\n", + "def node_b(param_b: str):\n", + " \"\"\"Descriptive string for the node.\n", + " \"\"\"\n", + " new_str = param_b*2\n", + " return {\"output\": new_str}" ] }, { @@ -23,7 +30,7 @@ { "data": { "text/plain": [ - "graphai.nodes.base._Node._node..NodeClass" + "\"param_a: \"" ] }, "execution_count": 2, @@ -32,7 +39,7 @@ } ], "source": [ - "my_node" + "node_a.get_signature()" ] }, { @@ -43,7 +50,7 @@ { "data": { "text/plain": [ - ".NodeClass at 0x1047fedd0>" + "{'param_b': 'Hello, World!'}" ] }, "execution_count": 3, @@ -52,47 +59,70 @@ } ], "source": [ - "my_node(param_a=\"l\")" + "node_a.invoke(input={\"param_a\": \"l\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from graphai import Graph\n", + "\n", + "graph = Graph()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, + "outputs": [], + "source": [ + "graph.add_node(node_a)\n", + "graph.add_node(node_b)\n", + "graph.add_edge(source=node_a, destination=node_b)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "(graphai.nodes.base._Node._node..NodeClass,\n", + " [graphai.nodes.base._Node._node..NodeClass])" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "my_node._func_signature" + "graph.start_node, graph.end_nodes" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'Hello, World!'" + "{'param_b': 'Hello, World!'}" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "my_node.invoke(input={\"param_a\": \"l\"})" + "graph.execute(input={\"param_a\": \"l\"})" ] }, { diff --git a/graphai/__init__.py b/graphai/__init__.py index 5a0aa20..3c415b9 100644 --- a/graphai/__init__.py +++ b/graphai/__init__.py @@ -1,3 +1,4 @@ +from graphai.graph import Graph from graphai.nodes import node -__all__ = ["node"] \ No newline at end of file +__all__ = ["node", "Graph"] \ No newline at end of file diff --git a/graphai/graph.py b/graphai/graph.py new file mode 100644 index 0000000..5304abd --- /dev/null +++ b/graphai/graph.py @@ -0,0 +1,61 @@ +class Graph: + def __init__(self): + self.nodes = [] + self.edges = [] + self.start_node = None + self.end_nodes = [] + + def add_node(self, node): + self.nodes.append(node) + if node.is_start: + if self.start_node is not None: + raise Exception("Multiple start nodes are not allowed.") + self.start_node = node + if node.is_end: + self.end_nodes.append(node) + + def add_edge(self, source, destination): + # TODO add logic to check that source and destination are nodes + # and they exist in the graph object already + edge = Edge(source, destination) + self.edges.append(edge) + + def set_start_node(self, node): + self.start_node = node + + def set_end_node(self, node): + self.end_node = node + + def compile(self): + if not self.start_node: + raise Exception("Start node not defined.") + if not self.end_nodes: + raise Exception("No end nodes defined.") + if not self._is_valid(): + raise Exception("Graph is not valid.") + print("Graph compiled successfully.") + + def _is_valid(self): + # Implement validation logic, e.g., checking for cycles, disconnected components, etc. + return True + + def execute(self, input): + current_node = self.start_node + while current_node not in self.end_nodes: + output = current_node.invoke(input=input) + current_node = self._get_next_node(current_node, output) + if current_node.is_end: + break + return output + + def _get_next_node(self, current_node, output): + for edge in self.edges: + if edge.source == current_node: + return edge.destination + raise Exception("No outgoing edge found for current node.") + + +class Edge: + def __init__(self, source, destination): + self.source = source + self.destination = destination \ No newline at end of file diff --git a/graphai/nodes/base.py b/graphai/nodes/base.py index 06236b6..008ddda 100644 --- a/graphai/nodes/base.py +++ b/graphai/nodes/base.py @@ -17,9 +17,9 @@ def __call__(cls, *args, **kwargs): class _Node: def __init__(self): - self._func_signature = None + pass - def _node(self, func: Callable) -> Callable: + def _node(self, func: Callable, start: bool = False, end: bool = False) -> Callable: """Decorator validating node structure. """ if not callable(func): @@ -42,6 +42,22 @@ def execute(self): bound_args.apply_defaults() return func(*bound_args.args, **bound_args.kwargs) + @classmethod + def get_signature(cls): + """Returns the signature of the decorated function as LLM readable + string. + """ + signature_components = [] + if NodeClass._func_signature: + for param in NodeClass._func_signature.parameters.values(): + if param.default is param.empty: + signature_components.append(f"{param.name}: {param.annotation}") + else: + signature_components.append(f"{param.name}: {param.annotation} = {param.default}") + else: + return "No signature" + return "\n".join(signature_components) + @classmethod def invoke(cls, input: Dict[str, Any]): instance = cls(**input) @@ -49,39 +65,21 @@ def invoke(cls, input: Dict[str, Any]): NodeClass.__name__ = func.__name__ NodeClass.__doc__ = func.__doc__ + NodeClass.is_start = start + NodeClass.is_end = end return NodeClass - def __call__(self, func: Optional[Callable] = None): + def __call__(self, func: Optional[Callable] = None, start: bool = False, end: bool = False): # We must wrap the call to the decorator in a function for it to work # correctly with or without parenthesis def wrap(func: Callable) -> Callable: - return self._node(func) + return self._node(func=func, start=start, end=end) if func: # Decorator is called without parenthesis - return wrap(func) + return wrap(func=func, start=start, end=end) # Decorator is called with parenthesis return wrap - - def _get_signature(self) -> inspect.Signature: - """Returns the signature of the decorated function. - """ - return self._func_signature - - def signature(self): - """Returns the signature of the decorated function as LLM readable - string. - """ - signature_str = "" - if self._func_signature: - for param in self._func_signature.parameters.values(): - if param.default is param.empty: - signature_str += f"{param.name}: {param.annotation}" - else: - signature_str += f"{param.name}: {param.annotation} = {param.default}" - else: - return "No signature" - return signature_str node = _Node()