2626from .._async import run_async
2727from ..agent import Agent
2828from ..agent .state import AgentState
29+ from ..experimental .hooks .multiagent import (
30+ AfterMultiAgentInvocationEvent ,
31+ AfterNodeCallEvent ,
32+ MultiAgentInitializedEvent ,
33+ )
34+ from ..hooks import HookProvider , HookRegistry
35+ from ..session import SessionManager
2936from ..telemetry import get_tracer
3037from ..types .content import ContentBlock , Messages
3138from ..types .event_loop import Metrics , Usage
3239from .base import MultiAgentBase , MultiAgentResult , NodeResult , Status
3340
3441logger = logging .getLogger (__name__ )
3542
43+ _DEFAULT_GRAPH_ID = "default_graph"
44+
3645
3746@dataclass
3847class GraphState :
@@ -216,6 +225,9 @@ def __init__(self) -> None:
216225 self ._execution_timeout : Optional [float ] = None
217226 self ._node_timeout : Optional [float ] = None
218227 self ._reset_on_revisit : bool = False
228+ self ._id : str = _DEFAULT_GRAPH_ID
229+ self ._session_manager : Optional [SessionManager ] = None
230+ self ._hooks : Optional [list [HookProvider ]] = None
219231
220232 def add_node (self , executor : Agent | MultiAgentBase , node_id : str | None = None ) -> GraphNode :
221233 """Add an Agent or MultiAgentBase instance as a node to the graph."""
@@ -306,6 +318,33 @@ def set_node_timeout(self, timeout: float) -> "GraphBuilder":
306318 self ._node_timeout = timeout
307319 return self
308320
321+ def set_graph_id (self , graph_id : str ) -> "GraphBuilder" :
322+ """Set graph id.
323+
324+ Args:
325+ graph_id: Unique graph id default to uuid4
326+ """
327+ self ._id = graph_id
328+ return self
329+
330+ def set_session_manager (self , session_manager : SessionManager ) -> "GraphBuilder" :
331+ """Set session manager for the graph.
332+
333+ Args:
334+ session_manager: SessionManager instance
335+ """
336+ self ._session_manager = session_manager
337+ return self
338+
339+ def set_hook_providers (self , hooks : list [HookProvider ]) -> "GraphBuilder" :
340+ """Set hook providers for the graph.
341+
342+ Args:
343+ hooks: Customer hooks user passes in
344+ """
345+ self ._hooks = hooks
346+ return self
347+
309348 def build (self ) -> "Graph" :
310349 """Build and validate the graph with configured settings."""
311350 if not self .nodes :
@@ -324,13 +363,16 @@ def build(self) -> "Graph":
324363 self ._validate_graph ()
325364
326365 return Graph (
366+ id = self ._id ,
327367 nodes = self .nodes .copy (),
328368 edges = self .edges .copy (),
329369 entry_points = self .entry_points .copy (),
330370 max_node_executions = self ._max_node_executions ,
331371 execution_timeout = self ._execution_timeout ,
332372 node_timeout = self ._node_timeout ,
333373 reset_on_revisit = self ._reset_on_revisit ,
374+ session_manager = self ._session_manager ,
375+ hooks = self ._hooks ,
334376 )
335377
336378 def _validate_graph (self ) -> None :
@@ -358,6 +400,10 @@ def __init__(
358400 execution_timeout : Optional [float ] = None ,
359401 node_timeout : Optional [float ] = None ,
360402 reset_on_revisit : bool = False ,
403+ session_manager : Optional [SessionManager ] = None ,
404+ hooks : Optional [list [HookProvider ]] = None ,
405+ * ,
406+ id : Optional [str ] = None ,
361407 ) -> None :
362408 """Initialize Graph with execution limits and reset behavior.
363409
@@ -369,11 +415,15 @@ def __init__(
369415 execution_timeout: Total execution timeout in seconds (default: None - no limit)
370416 node_timeout: Individual node timeout in seconds (default: None - no limit)
371417 reset_on_revisit: Whether to reset node state when revisited (default: False)
418+ session_manager: Session manager for persisting graph state and execution history (default: None)
419+ hooks: List of hook providers for monitoring and extending graph execution behavior (default: None)
420+ id: Unique graph id (default: None)
372421 """
373422 super ().__init__ ()
374423
375424 # Validate nodes for duplicate instances
376425 self ._validate_graph (nodes )
426+ self .id = id or _DEFAULT_GRAPH_ID
377427
378428 self .nodes = nodes
379429 self .edges = edges
@@ -384,6 +434,18 @@ def __init__(
384434 self .reset_on_revisit = reset_on_revisit
385435 self .state = GraphState ()
386436 self .tracer = get_tracer ()
437+ self .session_manager = session_manager
438+ self .hooks = HookRegistry ()
439+ if self .session_manager :
440+ self .hooks .add_hook (self .session_manager )
441+ if hooks :
442+ for hook in hooks :
443+ self .hooks .add_hook (hook )
444+
445+ self ._resume_next_nodes : list [GraphNode ] = []
446+ self ._resume_from_session = False
447+
448+ self .hooks .invoke_callbacks (MultiAgentInitializedEvent (self ))
387449
388450 def __call__ (
389451 self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
@@ -418,16 +480,20 @@ async def invoke_async(
418480
419481 logger .debug ("task=<%s> | starting graph execution" , task )
420482
421- # Initialize state
422483 start_time = time .time ()
423- self .state = GraphState (
424- status = Status .EXECUTING ,
425- task = task ,
426- total_nodes = len (self .nodes ),
427- edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
428- entry_points = list (self .entry_points ),
429- start_time = start_time ,
430- )
484+ if not self ._resume_from_session :
485+ # Initialize state
486+ self .state = GraphState (
487+ status = Status .EXECUTING ,
488+ task = task ,
489+ total_nodes = len (self .nodes ),
490+ edges = [(edge .from_node , edge .to_node ) for edge in self .edges ],
491+ entry_points = list (self .entry_points ),
492+ start_time = start_time ,
493+ )
494+ else :
495+ self .state .status = Status .EXECUTING
496+ self .state .start_time = start_time
431497
432498 span = self .tracer .start_multiagent_span (task , "graph" )
433499 with trace_api .use_span (span , end_on_exit = True ):
@@ -455,6 +521,9 @@ async def invoke_async(
455521 raise
456522 finally :
457523 self .state .execution_time = round ((time .time () - start_time ) * 1000 )
524+ self .hooks .invoke_callbacks (AfterMultiAgentInvocationEvent (self ))
525+ self ._resume_from_session = False
526+ self ._resume_next_nodes .clear ()
458527 return self ._build_result ()
459528
460529 def _validate_graph (self , nodes : dict [str , GraphNode ]) -> None :
@@ -471,7 +540,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
471540
472541 async def _execute_graph (self , invocation_state : dict [str , Any ]) -> None :
473542 """Unified execution flow with conditional routing."""
474- ready_nodes = list (self .entry_points )
543+ ready_nodes = self . _resume_next_nodes if self . _resume_from_session else list (self .entry_points )
475544
476545 while ready_nodes :
477546 # Check execution limits before continuing
@@ -608,6 +677,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
608677 node .node_id ,
609678 self .node_timeout ,
610679 )
680+ self .hooks .invoke_callbacks (AfterNodeCallEvent (self , node .node_id , invocation_state ))
611681 raise Exception (timeout_msg ) from None
612682
613683 # Mark as completed
@@ -621,6 +691,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
621691 # Accumulate metrics
622692 self ._accumulate_metrics (node_result )
623693
694+ self .hooks .invoke_callbacks (AfterNodeCallEvent (self , node .node_id , invocation_state ))
695+
624696 logger .debug (
625697 "node_id=<%s>, execution_time=<%dms> | node completed successfully" , node .node_id , node .execution_time
626698 )
@@ -644,6 +716,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
644716 node .execution_time = execution_time
645717 self .state .failed_nodes .add (node )
646718 self .state .results [node .node_id ] = node_result # Store in results for consistency
719+ self .hooks .invoke_callbacks (AfterNodeCallEvent (self , node .node_id , invocation_state ))
647720
648721 raise
649722
@@ -731,3 +804,94 @@ def _build_result(self) -> GraphResult:
731804 edges = self .state .edges ,
732805 entry_points = self .state .entry_points ,
733806 )
807+
808+ def serialize_state (self ) -> dict [str , Any ]:
809+ """Serialize the current graph state to a dictionary."""
810+ status_str = self .state .status .value
811+ compute_nodes = self ._compute_ready_nodes_for_resume ()
812+ next_nodes = [n .node_id for n in compute_nodes ] if compute_nodes else []
813+ return {
814+ "type" : "graph" ,
815+ "id" : self .id ,
816+ "status" : status_str ,
817+ "completed_nodes" : [n .node_id for n in self .state .completed_nodes ],
818+ "failed_nodes" : [n .node_id for n in self .state .failed_nodes ],
819+ "node_results" : {k : v .to_dict () for k , v in (self .state .results or {}).items ()},
820+ "next_nodes_to_execute" : next_nodes ,
821+ "current_task" : self .state .task ,
822+ "execution_order" : [n .node_id for n in self .state .execution_order ],
823+ }
824+
825+ def deserialize_state (self , payload : dict [str , Any ]) -> None :
826+ """Restore graph state from a session dict and prepare for execution.
827+
828+ This method handles two scenarios:
829+ 1. If the persisted status is COMPLETED, FAILED resets all nodes and graph state
830+ to allow re-execution from the beginning.
831+ 2. Otherwise, restores the persisted state and prepares to resume execution
832+ from the next ready nodes.
833+
834+ Args:
835+ payload: Dictionary containing persisted state data including status,
836+ completed nodes, results, and next nodes to execute.
837+ """
838+ if not payload .get ("next_nodes_to_execute" ):
839+ # Reset all nodes
840+ for node in self .nodes .values ():
841+ node .reset_executor_state ()
842+ # Reset graph state
843+ self .state = GraphState ()
844+ self ._resume_from_session = False
845+ return
846+ else :
847+ self ._from_dict (payload )
848+ self ._resume_from_session = True
849+
850+ # Helper functions for serialize and deserialize
851+ def _compute_ready_nodes_for_resume (self ) -> list [GraphNode ]:
852+ if self .state .status == Status .PENDING :
853+ return []
854+ ready_nodes : list [GraphNode ] = []
855+ completed_nodes = set (self .state .completed_nodes )
856+ for node in self .nodes .values ():
857+ if node in completed_nodes :
858+ continue
859+ incoming = [e for e in self .edges if e .to_node is node ]
860+ if not incoming :
861+ ready_nodes .append (node )
862+ elif all (e .from_node in completed_nodes and e .should_traverse (self .state ) for e in incoming ):
863+ ready_nodes .append (node )
864+
865+ return ready_nodes
866+
867+ def _from_dict (self , payload : dict [str , Any ]) -> None :
868+ self .state .status = Status (payload ["status" ])
869+ # Hydrate completed nodes & results
870+ raw_results = payload .get ("node_results" ) or {}
871+ results : dict [str , NodeResult ] = {}
872+ for node_id , entry in raw_results .items ():
873+ if node_id not in self .nodes :
874+ continue
875+ try :
876+ results [node_id ] = NodeResult .from_dict (entry )
877+ except Exception :
878+ logger .exception ("Failed to hydrate NodeResult for node_id=%s; skipping." , node_id )
879+ raise
880+ self .state .results = results
881+
882+ self .state .failed_nodes = set (payload .get ("failed_nodes" ) or [])
883+
884+ # Restore completed nodes from persisted data
885+ completed_node_ids = payload .get ("completed_nodes" ) or []
886+ self .state .completed_nodes = {self .nodes [node_id ] for node_id in completed_node_ids if node_id in self .nodes }
887+
888+ # Execution order (only nodes that still exist)
889+ order_node_ids = payload .get ("execution_order" ) or []
890+ self .state .execution_order = [self .nodes [node_id ] for node_id in order_node_ids if node_id in self .nodes ]
891+
892+ # Task
893+ self .state .task = payload .get ("current_task" , self .state .task )
894+
895+ # next nodes to execute
896+ next_nodes = [self .nodes [nid ] for nid in (payload .get ("next_nodes_to_execute" ) or []) if nid in self .nodes ]
897+ self ._resume_next_nodes = next_nodes
0 commit comments