22
22
from typing import Any , Callable , Tuple , cast
23
23
24
24
from ..agent import Agent , AgentResult
25
+ from ..types .content import ContentBlock
25
26
from ..types .event_loop import Metrics , Usage
26
27
from .base import MultiAgentBase , MultiAgentResult , NodeResult , Status
27
28
@@ -42,12 +43,14 @@ class GraphState:
42
43
Entry point nodes receive this task as their input if they have no dependencies.
43
44
"""
44
45
46
+ # Task (with default empty string)
47
+ task : str | list [ContentBlock ] = ""
48
+
45
49
# Execution state
46
50
status : Status = Status .PENDING
47
51
completed_nodes : set ["GraphNode" ] = field (default_factory = set )
48
52
failed_nodes : set ["GraphNode" ] = field (default_factory = set )
49
53
execution_order : list ["GraphNode" ] = field (default_factory = list )
50
- task : str = ""
51
54
52
55
# Results
53
56
results : dict [str , NodeResult ] = field (default_factory = dict )
@@ -247,7 +250,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
247
250
self .entry_points = entry_points
248
251
self .state = GraphState ()
249
252
250
- def execute (self , task : str ) -> GraphResult :
253
+ def execute (self , task : str | list [ ContentBlock ] ) -> GraphResult :
251
254
"""Execute task synchronously."""
252
255
253
256
def execute () -> GraphResult :
@@ -257,7 +260,7 @@ def execute() -> GraphResult:
257
260
future = executor .submit (execute )
258
261
return future .result ()
259
262
260
- async def execute_async (self , task : str ) -> GraphResult :
263
+ async def execute_async (self , task : str | list [ ContentBlock ] ) -> GraphResult :
261
264
"""Execute the graph asynchronously."""
262
265
logger .debug ("task=<%s> | starting graph execution" , task )
263
266
@@ -435,8 +438,8 @@ def _accumulate_metrics(self, node_result: NodeResult) -> None:
435
438
self .state .accumulated_metrics ["latencyMs" ] += node_result .accumulated_metrics .get ("latencyMs" , 0 )
436
439
self .state .execution_count += node_result .execution_count
437
440
438
- def _build_node_input (self , node : GraphNode ) -> str :
439
- """Build input text for a node based on dependency outputs."""
441
+ def _build_node_input (self , node : GraphNode ) -> list [ ContentBlock ] :
442
+ """Build input for a node based on dependency outputs."""
440
443
# Get satisfied dependencies
441
444
dependency_results = {}
442
445
for edge in self .edges :
@@ -449,21 +452,36 @@ def _build_node_input(self, node: GraphNode) -> str:
449
452
dependency_results [edge .from_node .node_id ] = self .state .results [edge .from_node .node_id ]
450
453
451
454
if not dependency_results :
452
- return self .state .task
455
+ # No dependencies - return task as ContentBlocks
456
+ if isinstance (self .state .task , str ):
457
+ return [ContentBlock (text = self .state .task )]
458
+ else :
459
+ return self .state .task
453
460
454
461
# Combine task with dependency outputs
455
- input_parts = [f"Original Task: { self .state .task } " , "\n Inputs from previous nodes:" ]
462
+ node_input = []
463
+
464
+ # Add original task
465
+ if isinstance (self .state .task , str ):
466
+ node_input .append (ContentBlock (text = f"Original Task: { self .state .task } " ))
467
+ else :
468
+ # Add task content blocks with a prefix
469
+ node_input .append (ContentBlock (text = "Original Task:" ))
470
+ node_input .extend (self .state .task )
471
+
472
+ # Add dependency outputs
473
+ node_input .append (ContentBlock (text = "\n Inputs from previous nodes:" ))
456
474
457
475
for dep_id , node_result in dependency_results .items ():
458
- input_parts .append (f"\n From { dep_id } :" )
476
+ node_input .append (ContentBlock ( text = f"\n From { dep_id } :" ) )
459
477
# Get all agent results from this node (flattened if nested)
460
478
agent_results = node_result .get_agent_results ()
461
479
for result in agent_results :
462
480
agent_name = getattr (result , "agent_name" , "Agent" )
463
481
result_text = str (result )
464
- input_parts .append (f" - { agent_name } : { result_text } " )
482
+ node_input .append (ContentBlock ( text = f" - { agent_name } : { result_text } " ) )
465
483
466
- return " \n " . join ( input_parts )
484
+ return node_input
467
485
468
486
def _build_result (self ) -> GraphResult :
469
487
"""Build graph result from current state."""
0 commit comments