Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/node_graph/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,16 @@ def _graph_runner(**kwargs):
sub_ng = self._build_subgraph(task, graph_fn, kwargs)
parent_pid = self._get_active_graph_pid()
self._run_subgraph(task, sub_ng, parent_pid)
return sub_ng.outputs._collect_values(raw=False)
return sub_ng.outputs._collect_values(unwrap=False)

return _graph_runner

@staticmethod
def _snapshot_builtins(ng: Graph) -> Dict[str, Dict[str, Any]]:
return {
"graph_ctx": ng.ctx._collect_values(raw=False),
"graph_inputs": ng.inputs._collect_values(raw=False),
"graph_outputs": ng.outputs._collect_values(raw=False),
"graph_ctx": ng.ctx._collect_values(unwrap=False),
"graph_inputs": ng.inputs._collect_values(unwrap=False),
"graph_outputs": ng.outputs._collect_values(unwrap=False),
}

def _graph_flow_run_id(self, ng: Graph) -> str:
Expand All @@ -88,7 +88,7 @@ def _start_graph_run(
parent_pid=parent_pid,
)
self.recorder.record_inputs_payload(
graph_pid, ng.inputs._collect_values(raw=False)
graph_pid, ng.inputs._collect_values(unwrap=False)
)
return graph_pid

Expand Down Expand Up @@ -130,7 +130,7 @@ def _normalize_outputs(
f"Failed to parse outputs for task '{task.name}': {e}"
) from e
tag_socket_value(task.outputs, only_uuid=True)
return task.outputs._collect_values(raw=False)
return task.outputs._collect_values(unwrap=False)

def _link_socket_value(
self, from_name: str, from_socket: str, source_map: Dict[str, Any]
Expand Down
9 changes: 7 additions & 2 deletions src/node_graph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any]
return data


def _collect_literals(task, raw=False) -> Dict[str, Any]:
def _collect_literals(task, unwrap=False) -> Dict[str, Any]:
"""
Recursively collect literal values from the task's input namespace, excluding
values that are overridden by links at schedule time.
"""
from node_graph.utils import tag_socket_value

tag_socket_value(task.inputs, only_uuid=True)
return task.inputs._collect_values(raw=raw)
return task.inputs._collect_values(unwrap=unwrap)


def _resolve_tagged_value(value: Any) -> Any:
Expand Down Expand Up @@ -210,6 +210,11 @@ def _build_task_link_kwargs(

kwargs: Dict[str, Any] = {}
for to_sock, lks in grouped.items():
if any(
lk.to_socket._metadata.extras.get("value_source") == "property"
for lk in lks
):
continue
# ignore _wait edges for value propagation (handled separately)
active_links = [lk for lk in lks if lk.from_socket._scoped_name != "_wait"]
if not active_links:
Expand Down
112 changes: 91 additions & 21 deletions src/node_graph/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from node_graph.collection import get_item_class
from dataclasses import MISSING, replace
from node_graph.orm.mapping import type_mapping
from node_graph.socket_meta import SocketMeta
from node_graph.socket_meta import SocketMeta, UPDATABLE_SOCKET_META_FIELDS
from node_graph.registry import EntryPointPool
import wrapt

Expand Down Expand Up @@ -591,6 +591,17 @@ def _to_dict(self) -> Dict[str, Any]:
data["deserialize"] = self.get_deserialize()
return data

def _update_updatable_meta(self, payload: Dict[str, Any]) -> None:
"""Update whitelisted metadata extras on the socket."""
extras = self._metadata.extras
for key, value in payload.items():
if key not in UPDATABLE_SOCKET_META_FIELDS:
continue
if value is None or (key == "value_source" and value == "link"):
extras.pop(key, None)
else:
extras[key] = value


class TaskSocket(BaseSocket, OperatorSocketMixin):
_identifier: str = "TaskSocket"
Expand Down Expand Up @@ -653,34 +664,53 @@ def _value(self) -> Any:

@property
def value(self) -> Any:
if (
self._task is not None
and self._task._input_resolver is not None
and self._full_name.split(".")[0] == "inputs"
and self._metadata.extras.get("value_source") != "property"
):
return self._task._input_resolver(self)
return self._value

@value.setter
def value(self, value: Any) -> None:
self._set_socket_value(value)

def _set_socket_value(self, value: Any) -> None:
def _set_socket_value(self, value: Any, *, value_source: str = "link") -> None:
if isinstance(value, BaseSocket):
if (
isinstance(value, TaskSocketNamespace)
and value._parent is None
and "_outputs" in value
):
value = value._outputs
self._update_updatable_meta({"value_source": "link"})
self._task.graph.add_link(value, self)
elif isinstance(value, TaggedValue) and value._socket is not None:
self._update_updatable_meta({"value_source": "link"})
self._task.graph.add_link(value._socket, self)
elif self.property:
if self._full_name.split(".")[0] == "inputs" and any(
is_input = self._full_name.split(".")[0] == "inputs"
has_link = any(
[
link.to_socket._full_name_with_task == self._full_name_with_task
for link in self._links
]
):
raise ValueError(
f"Input {self._full_name_with_task} has already been set via a "
f"link. Please update the linked value in {self._links}"
)
if is_input and has_link:
override = (
value_source == "property"
or self._metadata.extras.get("value_source") == "property"
)
if not override:
raise ValueError(
f"Input {self._full_name_with_task} has already been set via a "
f"link. Please update the linked value in {self._links}"
)
self._update_updatable_meta({"value_source": "property"})
elif is_input and value_source != "property":
self._update_updatable_meta({"value_source": "link"})

self.property.value = value
else:
Expand Down Expand Up @@ -999,23 +1029,22 @@ def _new(
def _value(self) -> Dict[str, Any]:
return self._collect_values()

def _collect_values(self, raw: bool = True) -> Dict[str, Any]:
def _collect_values(
self, unwrap: bool = True, resolve: bool = False
) -> Dict[str, Any]:
data = {}
for name, item in self._sockets.items():
if isinstance(item, TaskSocketNamespace):
value = item._collect_values(raw=raw)
value = item._collect_values(unwrap=unwrap, resolve=resolve)
if value:
data[name] = value
else:
if item.value is not None:
if raw:
data[name] = (
item.value.__wrapped__
if isinstance(item.value, TaggedValue)
else item.value
)
value = item.value if resolve else item._value
if value is not None:
if unwrap and isinstance(value, TaggedValue):
data[name] = value.__wrapped__
else:
data[name] = item.value
data[name] = value
return data

def _to_spec(self) -> "SocketSpec":
Expand Down Expand Up @@ -1112,7 +1141,9 @@ def _spec_from_shape_snapshot(snapshot: Dict[str, Any]) -> "SocketSpec":
def _value(self, value: Dict[str, Any]) -> None:
self._set_socket_value(value)

def _set_socket_value(self, value: Dict[str, Any] | TaskSocket) -> None:
def _set_socket_value(
self, value: Dict[str, Any] | TaskSocket, *, value_source: str = "link"
) -> None:
"""Set value(s) into this namespace.

Supports:
Expand All @@ -1131,6 +1162,7 @@ def _set_socket_value(self, value: Dict[str, Any] | TaskSocket) -> None:

# Link another socket directly to this namespace
if isinstance(value, BaseSocket):
self._clear_updatable_meta()
self._task.graph.add_link(value, self)
return

Expand Down Expand Up @@ -1207,7 +1239,7 @@ def _set_socket_value(self, value: Dict[str, Any] | TaskSocket) -> None:
)

# Recurse into the child namespace with the remaining tail
child._set_socket_value({tail: val})
child._set_socket_value({tail: val}, value_source=value_source)
continue # next key

# Non-dotted key path (single-segment)
Expand Down Expand Up @@ -1253,7 +1285,7 @@ def _set_socket_value(self, value: Dict[str, Any] | TaskSocket) -> None:
if isinstance(target, TaskSocketNamespace):
# If incoming val is a dict, recurse. If it’s a socket, link to the namespace.
if isinstance(val, dict):
target._set_socket_value(val)
target._set_socket_value(val, value_source=value_source)
elif isinstance(val, BaseSocket):
self._task.graph.add_link(val, target)
else:
Expand All @@ -1268,7 +1300,45 @@ def _set_socket_value(self, value: Dict[str, Any] | TaskSocket) -> None:
)
else:
# Leaf socket: forward to its own setter (which handles linking or value assignment)
target._set_socket_value(val)
target._set_socket_value(val, value_source=value_source)

def _export_updatable_meta_map(self) -> Dict[str, Dict[str, Any]]:
"""Export whitelisted metadata extras keyed by scoped socket path."""
meta_map: Dict[str, Dict[str, Any]] = {}
for name, child in self._sockets.items():
child_payload = {
key: child._metadata.extras.get(key)
for key in UPDATABLE_SOCKET_META_FIELDS
if key in child._metadata.extras
}
if child_payload:
meta_map[name] = child_payload
if isinstance(child, TaskSocketNamespace):
nested = child._export_updatable_meta_map()
for nested_name, payload in nested.items():
meta_map[f"{name}.{nested_name}"] = payload
return meta_map

def _apply_updatable_meta_map(self, meta_map: Dict[str, Dict[str, Any]]) -> None:
"""Apply whitelisted metadata extras from a scoped map."""
for path, payload in meta_map.items():
try:
target = self[path]
except Exception:
continue
if hasattr(target, "_update_updatable_meta"):
target._update_updatable_meta(payload)

def _clear_updatable_meta(self) -> None:
"""Clear all whitelisted metadata extras in this namespace subtree."""
self._update_updatable_meta({key: None for key in UPDATABLE_SOCKET_META_FIELDS})
for child in self._sockets.values():
if hasattr(child, "_clear_updatable_meta"):
child._clear_updatable_meta()
elif hasattr(child, "_update_updatable_meta"):
child._update_updatable_meta(
{key: None for key in UPDATABLE_SOCKET_META_FIELDS}
)

@property
def _all_links(self) -> List["TaskLink"]:
Expand Down
2 changes: 2 additions & 0 deletions src/node_graph/socket_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from enum import Enum
from typing import Any, Dict, Mapping, Optional

UPDATABLE_SOCKET_META_FIELDS = {"value_source"}


class CallRole(str, Enum):
"""Defines how a socket's value is used in a function call."""
Expand Down
Loading