Skip to content

Commit 56e85c6

Browse files
committed
feat: enable multiagent session persistent
1 parent 95906fa commit 56e85c6

File tree

7 files changed

+644
-24
lines changed

7 files changed

+644
-24
lines changed

src/strands/multiagent/graph.py

Lines changed: 174 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,22 @@
2626
from .._async import run_async
2727
from ..agent import Agent
2828
from ..agent.state import AgentState
29+
from ..experimental.hooks.multiagent import (
30+
AfterMultiAgentInvocationEvent,
31+
AfterNodeCallEvent,
32+
MultiAgentInitializedEvent,
33+
)
34+
from ..hooks import HookProvider, HookRegistry
35+
from ..session import SessionManager
2936
from ..telemetry import get_tracer
3037
from ..types.content import ContentBlock, Messages
3138
from ..types.event_loop import Metrics, Usage
3239
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
3340

3441
logger = logging.getLogger(__name__)
3542

43+
_DEFAULT_GRAPH_ID = "default_graph"
44+
3645

3746
@dataclass
3847
class GraphState:
@@ -216,6 +225,9 @@ def __init__(self) -> None:
216225
self._execution_timeout: Optional[float] = None
217226
self._node_timeout: Optional[float] = None
218227
self._reset_on_revisit: bool = False
228+
self._id: str = _DEFAULT_GRAPH_ID
229+
self._session_manager: Optional[SessionManager] = None
230+
self._hooks: Optional[list[HookProvider]] = None
219231

220232
def add_node(self, executor: Agent | MultiAgentBase, node_id: str | None = None) -> GraphNode:
221233
"""Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -306,6 +318,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
306318
self._node_timeout = timeout
307319
return self
308320

321+
def set_graph_id(self, graph_id: str) -> "GraphBuilder":
322+
"""Set graph id.
323+
324+
Args:
325+
graph_id: Unique graph id default to uuid4
326+
"""
327+
self._id = graph_id
328+
return self
329+
330+
def set_session_manager(self, session_manager: SessionManager) -> "GraphBuilder":
331+
"""Set session manager for the graph.
332+
333+
Args:
334+
session_manager: SessionManager instance
335+
"""
336+
self._session_manager = session_manager
337+
return self
338+
339+
def set_hook_providers(self, hooks: list[HookProvider]) -> "GraphBuilder":
340+
"""Set hook providers for the graph.
341+
342+
Args:
343+
hooks: Customer hooks user passes in
344+
"""
345+
self._hooks = hooks
346+
return self
347+
309348
def build(self) -> "Graph":
310349
"""Build and validate the graph with configured settings."""
311350
if not self.nodes:
@@ -324,13 +363,16 @@ def build(self) -> "Graph":
324363
self._validate_graph()
325364

326365
return Graph(
366+
id=self._id,
327367
nodes=self.nodes.copy(),
328368
edges=self.edges.copy(),
329369
entry_points=self.entry_points.copy(),
330370
max_node_executions=self._max_node_executions,
331371
execution_timeout=self._execution_timeout,
332372
node_timeout=self._node_timeout,
333373
reset_on_revisit=self._reset_on_revisit,
374+
session_manager=self._session_manager,
375+
hooks=self._hooks,
334376
)
335377

336378
def _validate_graph(self) -> None:
@@ -358,6 +400,10 @@ def __init__(
358400
execution_timeout: Optional[float] = None,
359401
node_timeout: Optional[float] = None,
360402
reset_on_revisit: bool = False,
403+
session_manager: Optional[SessionManager] = None,
404+
hooks: Optional[list[HookProvider]] = None,
405+
*,
406+
id: Optional[str] = None,
361407
) -> None:
362408
"""Initialize Graph with execution limits and reset behavior.
363409
@@ -369,11 +415,15 @@ def __init__(
369415
execution_timeout: Total execution timeout in seconds (default: None - no limit)
370416
node_timeout: Individual node timeout in seconds (default: None - no limit)
371417
reset_on_revisit: Whether to reset node state when revisited (default: False)
418+
session_manager: Session manager for persisting graph state and execution history (default: None)
419+
hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
420+
id: Unique graph id (default: None)
372421
"""
373422
super().__init__()
374423

375424
# Validate nodes for duplicate instances
376425
self._validate_graph(nodes)
426+
self.id = id or _DEFAULT_GRAPH_ID
377427

378428
self.nodes = nodes
379429
self.edges = edges
@@ -384,6 +434,18 @@ def __init__(
384434
self.reset_on_revisit = reset_on_revisit
385435
self.state = GraphState()
386436
self.tracer = get_tracer()
437+
self.session_manager = session_manager
438+
self.hooks = HookRegistry()
439+
if self.session_manager:
440+
self.hooks.add_hook(self.session_manager)
441+
if hooks:
442+
for hook in hooks:
443+
self.hooks.add_hook(hook)
444+
445+
self._resume_next_nodes: list[GraphNode] = []
446+
self._resume_from_session = False
447+
448+
self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self))
387449

388450
def __call__(
389451
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
@@ -418,16 +480,20 @@ async def invoke_async(
418480

419481
logger.debug("task=<%s> | starting graph execution", task)
420482

421-
# Initialize state
422483
start_time = time.time()
423-
self.state = GraphState(
424-
status=Status.EXECUTING,
425-
task=task,
426-
total_nodes=len(self.nodes),
427-
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
428-
entry_points=list(self.entry_points),
429-
start_time=start_time,
430-
)
484+
if not self._resume_from_session:
485+
# Initialize state
486+
self.state = GraphState(
487+
status=Status.EXECUTING,
488+
task=task,
489+
total_nodes=len(self.nodes),
490+
edges=[(edge.from_node, edge.to_node) for edge in self.edges],
491+
entry_points=list(self.entry_points),
492+
start_time=start_time,
493+
)
494+
else:
495+
self.state.status = Status.EXECUTING
496+
self.state.start_time = start_time
431497

432498
span = self.tracer.start_multiagent_span(task, "graph")
433499
with trace_api.use_span(span, end_on_exit=True):
@@ -455,6 +521,9 @@ async def invoke_async(
455521
raise
456522
finally:
457523
self.state.execution_time = round((time.time() - start_time) * 1000)
524+
self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self))
525+
self._resume_from_session = False
526+
self._resume_next_nodes.clear()
458527
return self._build_result()
459528

460529
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
@@ -471,7 +540,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
471540

472541
async def _execute_graph(self, invocation_state: dict[str, Any]) -> None:
473542
"""Unified execution flow with conditional routing."""
474-
ready_nodes = list(self.entry_points)
543+
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
475544

476545
while ready_nodes:
477546
# Check execution limits before continuing
@@ -608,6 +677,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
608677
node.node_id,
609678
self.node_timeout,
610679
)
680+
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
611681
raise Exception(timeout_msg) from None
612682

613683
# Mark as completed
@@ -621,6 +691,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
621691
# Accumulate metrics
622692
self._accumulate_metrics(node_result)
623693

694+
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
695+
624696
logger.debug(
625697
"node_id=<%s>, execution_time=<%dms> | node completed successfully", node.node_id, node.execution_time
626698
)
@@ -644,6 +716,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
644716
node.execution_time = execution_time
645717
self.state.failed_nodes.add(node)
646718
self.state.results[node.node_id] = node_result # Store in results for consistency
719+
self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state))
647720

648721
raise
649722

@@ -731,3 +804,94 @@ def _build_result(self) -> GraphResult:
731804
edges=self.state.edges,
732805
entry_points=self.state.entry_points,
733806
)
807+
808+
def serialize_state(self) -> dict[str, Any]:
809+
"""Serialize the current graph state to a dictionary."""
810+
status_str = self.state.status.value
811+
compute_nodes = self._compute_ready_nodes_for_resume()
812+
next_nodes = [n.node_id for n in compute_nodes] if compute_nodes else []
813+
return {
814+
"type": "graph",
815+
"id": self.id,
816+
"status": status_str,
817+
"completed_nodes": [n.node_id for n in self.state.completed_nodes],
818+
"failed_nodes": [n.node_id for n in self.state.failed_nodes],
819+
"node_results": {k: v.to_dict() for k, v in (self.state.results or {}).items()},
820+
"next_nodes_to_execute": next_nodes,
821+
"current_task": self.state.task,
822+
"execution_order": [n.node_id for n in self.state.execution_order],
823+
}
824+
825+
def deserialize_state(self, payload: dict[str, Any]) -> None:
826+
"""Restore graph state from a session dict and prepare for execution.
827+
828+
This method handles two scenarios:
829+
1. If the persisted status is COMPLETED, FAILED resets all nodes and graph state
830+
to allow re-execution from the beginning.
831+
2. Otherwise, restores the persisted state and prepares to resume execution
832+
from the next ready nodes.
833+
834+
Args:
835+
payload: Dictionary containing persisted state data including status,
836+
completed nodes, results, and next nodes to execute.
837+
"""
838+
if not payload.get("next_nodes_to_execute"):
839+
# Reset all nodes
840+
for node in self.nodes.values():
841+
node.reset_executor_state()
842+
# Reset graph state
843+
self.state = GraphState()
844+
self._resume_from_session = False
845+
return
846+
else:
847+
self._from_dict(payload)
848+
self._resume_from_session = True
849+
850+
# Helper functions for serialize and deserialize
851+
def _compute_ready_nodes_for_resume(self) -> list[GraphNode]:
852+
if self.state.status == Status.PENDING:
853+
return []
854+
ready_nodes: list[GraphNode] = []
855+
completed_nodes = set(self.state.completed_nodes)
856+
for node in self.nodes.values():
857+
if node in completed_nodes:
858+
continue
859+
incoming = [e for e in self.edges if e.to_node is node]
860+
if not incoming:
861+
ready_nodes.append(node)
862+
elif all(e.from_node in completed_nodes and e.should_traverse(self.state) for e in incoming):
863+
ready_nodes.append(node)
864+
865+
return ready_nodes
866+
867+
def _from_dict(self, payload: dict[str, Any]) -> None:
868+
self.state.status = Status(payload["status"])
869+
# Hydrate completed nodes & results
870+
raw_results = payload.get("node_results") or {}
871+
results: dict[str, NodeResult] = {}
872+
for node_id, entry in raw_results.items():
873+
if node_id not in self.nodes:
874+
continue
875+
try:
876+
results[node_id] = NodeResult.from_dict(entry)
877+
except Exception:
878+
logger.exception("Failed to hydrate NodeResult for node_id=%s; skipping.", node_id)
879+
raise
880+
self.state.results = results
881+
882+
self.state.failed_nodes = set(payload.get("failed_nodes") or [])
883+
884+
# Restore completed nodes from persisted data
885+
completed_node_ids = payload.get("completed_nodes") or []
886+
self.state.completed_nodes = {self.nodes[node_id] for node_id in completed_node_ids if node_id in self.nodes}
887+
888+
# Execution order (only nodes that still exist)
889+
order_node_ids = payload.get("execution_order") or []
890+
self.state.execution_order = [self.nodes[node_id] for node_id in order_node_ids if node_id in self.nodes]
891+
892+
# Task
893+
self.state.task = payload.get("current_task", self.state.task)
894+
895+
# next nodes to execute
896+
next_nodes = [self.nodes[nid] for nid in (payload.get("next_nodes_to_execute") or []) if nid in self.nodes]
897+
self._resume_next_nodes = next_nodes

0 commit comments

Comments
 (0)