Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions docs/gallery/autogen/zone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
==================
Zone tasks
==================

`node-graph` supports a context-manager style for building graphs inline.
The context-manager style also exposes zones such as ``If`` and ``While``
and a shared ``ctx`` object for passing values between tasks.
"""


# %%
# Zones: If and While
# -------------------
# Zones allow you to gate or repeat a block of tasks based on a condition.
# The condition itself is defined by a task result.
#
# ``ctx`` is a shared store for intermediate values. You can write to it
# from one task and read it as input in another task.

from node_graph import task
from node_graph import If, While, get_current_graph


@task()
def smaller_than(x, y):
return x < y


@task()
def add(x, y):
return x + y


@task()
def is_even(x):
return x % 2 == 0


@task.graph()
def while_with_if(index=0, limit=10, total=0, increment=1):
graph = get_current_graph()
graph.ctx.total = total
graph.ctx.index = index
condition = smaller_than(graph.ctx.index, limit).result

with While(condition):
is_even_cond = is_even(graph.ctx.index).result
with If(is_even_cond) as if_zone:
graph.ctx.total = add(
x=graph.ctx.total,
y=graph.ctx.index,
).result
next_index = add(x=graph.ctx.index, y=increment).result
graph.ctx.index = next_index
if_zone >> next_index

return graph.ctx.total


g = while_with_if.build(index=0, limit=10, total=0, increment=1)
g.to_html()

# %%
# If and While are available for both context-manager graphs and ``@task.graph``
# workflows.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Then spin up your first graph:
autogen/quick_start
autogen/annotate_inputs_outputs
autogen/annotate_semantics
autogen/zone
concept/index
yaml
customize
Expand Down
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ tests = [
"node_graph.test_add" = "node_graph.tasks.tests:test_add"
"node_graph.test_enum" = "node_graph.tasks.tests:TestEnum"
"node_graph.test_enum_update" = "node_graph.tasks.tests:TestEnumUpdate"
"node_graph.while_zone" = "node_graph.tasks.builtins:While"
"node_graph.zone" = "node_graph.tasks.builtins:Zone"
"node_graph.if_zone" = "node_graph.tasks.builtins:If"

[project.entry-points."node_graph.socket"]
"node_graph.any" = "node_graph.sockets.builtins:SocketAny"
Expand Down Expand Up @@ -97,10 +100,10 @@ tests = [
"node_graph.base_list" = "node_graph.properties.builtins:PropertyBaseList"

[project.entry-points."node_graph.type_mapping"]
"workgraph.builtins_mapping" = "node_graph.orm.mapping:type_mapping"
"node_graph.builtins_mapping" = "node_graph.orm.mapping:type_mapping"

[project.entry-points."node_graph.type_promotion"]
"workgraph.builtins_mapping" = "node_graph.link:TYPE_PROMOTIONS"
"node_graph.builtins_mapping" = "node_graph.link:TYPE_PROMOTIONS"

[project.scripts]
node-graph = "node_graph.cli:cli"
4 changes: 4 additions & 0 deletions src/node_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .tasks import TaskPool
from .collection import group
from .socket_spec import namespace, dynamic
from .manager import get_current_graph, While, If

__version__ = "0.5.1"

Expand All @@ -20,5 +21,8 @@
"group",
"namespace",
"dynamic",
"get_current_graph",
"While",
"If",
"KnowledgeGraph",
]
5 changes: 1 addition & 4 deletions src/node_graph/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def build_zone(self) -> Dict[str, Dict[str, List[str]]]:
}

*External inputs* are compressed to the **closest ancestor** outside
the zone – matching the logic in ``WorkGraphSaver.find_zone_inputs``.
the zone.
"""
self._ensure_cache_valid()
return self._cache_zone
Expand Down Expand Up @@ -303,15 +303,13 @@ def _compute_zone_cache(self):
"""
Populate ``_cache_zone`` for every task (each task is a "zone").

This method faithfully implements the WorkGraphSaver logic:

1) **Direct child / parent mapping**
- Build `direct_children[task]` from `task.children`.
- Build `parent_of[child] = parent` for each child.

2) **Parent chains**
- `parent_chain(name)` returns a list [`parent`, ..., None]
matching WorkGraphSaver.update_parent_task.

3) **Recursive zone input discovery** via `zone_inputs(zone_name)`:

Expand Down Expand Up @@ -352,7 +350,6 @@ def _compute_zone_cache(self):
for k in kids:
parent_of[k] = parent

# helper – parent chain (like WorkGraphSaver.update_parent_task)
parent_chain_cache: Dict[str, List[Optional[str]]] = {}

def parent_chain(name: str) -> List[Optional[str]]:
Expand Down
117 changes: 95 additions & 22 deletions src/node_graph/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
"""
from contextlib import contextmanager
from contextvars import ContextVar
from node_graph.tasks.task_pool import TaskPool
from node_graph.socket import TaskSocket

_current_graph: ContextVar["Graph | None"] = ContextVar("current_graph", default=None)


class CurrentGraphManager:
Expand All @@ -20,70 +24,139 @@ def __new__(cls, *args, **kwargs):
"""
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._graph = None # Storage for the active graph
return cls._instance

def peek_current_graph(self):
"""Return the active graph or None (do NOT auto-create)."""
return self._graph
return _current_graph.get()

def get_current_graph(self):
"""
Retrieve the current graph, or create a new one if none is set.
"""
from node_graph.graph import Graph

if self._graph is None:
self._graph = Graph()
return self._graph
g = _current_graph.get()
if g is None:
g = Graph()
_current_graph.set(g)
return g

def set_current_graph(self, graph):
"""
Set the active graph to the given instance.
"""
self._graph = graph
_current_graph.set(graph)

@contextmanager
def active_graph(self, graph):
"""
Context manager that temporarily overrides the current graph
with `graph`, restoring the old graph when exiting the context.
"""
old_graph = self._graph
self._graph = graph
token = _current_graph.set(graph)
try:
yield graph
finally:
self._graph = old_graph
_current_graph.reset(token)


# Create a global manager instance
_manager = CurrentGraphManager()
_current_graph: ContextVar["Graph | None"] = ContextVar("current_graph", default=None)


def peek_current_graph():
return _current_graph.get()
return _manager.peek_current_graph()


def get_current_graph():
from node_graph.graph import Graph

g = _current_graph.get()
if g is None:
g = Graph() # fallback to a default core graph
_current_graph.set(g)
return g
return _manager.get_current_graph()


def set_current_graph(graph):
_current_graph.set(graph)
_manager.set_current_graph(graph)


@contextmanager
def active_graph(graph):
token = _current_graph.set(graph)
with _manager.active_graph(graph) as ctx:
yield ctx


@contextmanager
def Zone():
"""
Context manager to create a "zone" in the current graph.
"""

graph = get_current_graph()

zone_task = graph.add_task(
TaskPool.node_graph.zone,
)

old_zone = getattr(graph, "_active_zone", None)
if old_zone:
old_zone.children.add(zone_task)
graph._active_zone = zone_task

try:
yield zone_task
finally:
graph._active_zone = old_zone


@contextmanager
def If(condition_socket: TaskSocket, invert_condition: bool = False):
"""
Context manager to create a "conditional zone" in the current graph.

:param condition_socket: A TaskSocket or boolean-like object (e.g. sum_ > 0)
:param invert_condition: Whether to invert the condition (useful for else-zones)
"""

graph = get_current_graph()

zone_task = graph.add_task(
TaskPool.node_graph.if_zone,
conditions=condition_socket,
invert_condition=invert_condition,
)

old_zone = getattr(graph, "_active_zone", None)
if old_zone:
old_zone.children.add(zone_task)
graph._active_zone = zone_task

try:
yield zone_task
finally:
graph._active_zone = old_zone


@contextmanager
def While(condition_socket: TaskSocket, max_iterations: int = 10000):
"""
Context manager to create a "while zone" in the current graph.

:param condition_socket: A TaskSocket or boolean-like object (e.g. sum_ > 0)
:param max_iterations: Maximum number of iterations before breaking the loop
"""

graph = get_current_graph()

zone_task = graph.add_task(
TaskPool.node_graph.while_zone,
conditions=condition_socket,
max_iterations=max_iterations,
)

old_zone = getattr(graph, "_active_zone", None)
if old_zone:
old_zone.children.add(zone_task)
graph._active_zone = zone_task

try:
yield graph
yield zone_task
finally:
_current_graph.reset(token)
graph._active_zone = old_zone
6 changes: 3 additions & 3 deletions src/node_graph/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _raise_illegal(sock, what: str, tips: list[str]):
common = [
"General guidance:",
" • Wrap logic in a nested @task.graph.",
" • Or use the WorkGraph If zone for branching on predicates.",
" • Or use the Graph If zone for branching on predicates.",
" • Or for loops, use the While zone or Map zone.",
]

Expand Down Expand Up @@ -155,14 +155,14 @@ def _decorator(self):
return task

def _create_operator_task(self, op_func, x, y):
"""Create a "hidden" operator Task in the WorkGraph,
"""Create a "hidden" operator Task in the Graph,
hooking `self` up as 'x' and `other` as 'y'.
Return the output socket from that new Task.
"""

graph = self._task.graph
if not graph:
raise ValueError("Socket does not belong to a WorkGraph.")
raise ValueError("Socket does not belong to a Graph.")

new_node = graph.tasks._new(
self._decorator()(op_func),
Expand Down
Loading