Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions flo_ai/flo_ai/arium/arium.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flo_ai.tool.base_tool import Tool
from flo_ai.arium.models import StartNode, EndNode
from flo_ai.arium.events import AriumEventType, AriumEvent
from flo_ai.arium.nodes import AriumNode, ForEachNode
from flo_ai.utils.logger import logger
from flo_ai.utils.variable_extractor import (
extract_variables_from_inputs,
Expand Down Expand Up @@ -73,14 +74,16 @@ async def run(

# Execute the workflow with event support
result = await self._execute_graph(
resolved_inputs, event_callback, events_filter
resolved_inputs, event_callback, events_filter, variables
)

# Emit workflow completed event
self._emit_event(
AriumEventType.WORKFLOW_COMPLETED, event_callback, events_filter
)

self.memory = MessageMemory() # cleanup the graph (if used as AriumNode multiple times in graph, then the same instance is used for now hence we need to cleanup memory)

return result

except Exception as e:
Expand Down Expand Up @@ -118,6 +121,7 @@ async def _execute_graph(
inputs: List[str | ImageMessage | DocumentMessage],
event_callback: Optional[Callable[[AriumEvent], None]] = None,
events_filter: Optional[List[AriumEventType]] = None,
variables: Optional[Dict[str, Any]] = None,
):
[self.memory.add(msg) for msg in inputs]

Expand Down Expand Up @@ -162,11 +166,16 @@ async def _execute_graph(
)
# execute current node
result = await self._execute_node(
current_node, event_callback, events_filter
current_node, event_callback, events_filter, variables
)

# update results to memory
self._add_to_memory(result)
if isinstance(result, List): # for each node will give results array
for item in result:
# update each item in result to memory
self._add_to_memory(item)
else:
# update results to memory
self._add_to_memory(result)

# find next node post current node
# Prepare execution context for router functions
Expand Down Expand Up @@ -301,6 +310,7 @@ async def _execute_node(
node: Agent | Tool | StartNode | EndNode,
event_callback: Optional[Callable[[AriumEvent], None]] = None,
events_filter: Optional[List[AriumEventType]] = None,
variables: Optional[Dict[str, Any]] = None,
):
"""
Execute a single node with optional event emission.
Expand All @@ -318,6 +328,10 @@ async def _execute_node(
node_type = 'agent'
elif isinstance(node, Tool):
node_type = 'tool'
elif isinstance(node, ForEachNode):
node_type = 'foreach'
elif isinstance(node, AriumNode):
node_type = 'arium'
elif isinstance(node, StartNode):
node_type = 'start'
elif isinstance(node, EndNode):
Expand All @@ -342,7 +356,16 @@ async def _execute_node(
# Variables are already resolved, pass empty dict to avoid re-processing
result = await node.run(self.memory.get(), variables={})
elif isinstance(node, Tool):
result = await node.execute()
# result = await node.execute() # as Tool is also an ExecutableNode now
result = await node.run(inputs=[], variables=None)
elif isinstance(node, ForEachNode):
result = await node.run(
inputs=self.memory.get(),
variables=variables,
)
elif isinstance(node, AriumNode):
# AriumNode execution
result = await node.run(inputs=self.memory.get(), variables=variables)
elif isinstance(node, StartNode):
result = None
elif isinstance(node, EndNode):
Expand Down
14 changes: 10 additions & 4 deletions flo_ai/flo_ai/arium/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import inspect
from functools import partial
from flo_ai.arium.nodes import AriumNode, ForEachNode
from flo_ai.arium.protocols import ExecutableNode
from flo_ai.models.agent import Agent
from flo_ai.tool.base_tool import Tool
from flo_ai.utils.logger import logger
Expand All @@ -12,13 +14,13 @@ class BaseArium:
def __init__(self):
self.start_node_name = '__start__'
self.end_node_names: set = set() # Support multiple end nodes
self.nodes: Dict[str, Agent | Tool | StartNode | EndNode] = dict()
self.nodes: Dict[str, ExecutableNode | StartNode | EndNode] = dict()
self.edges: Dict[str, Edge] = dict()

def add_nodes(self, agents: List[Agent | Tool | StartNode | EndNode]):
def add_nodes(self, agents: List[ExecutableNode | StartNode | EndNode]):
self.nodes.update({agent.name: agent for agent in agents})

def start_at(self, node: Agent | Tool | StartNode | EndNode):
def start_at(self, node: ExecutableNode):
start_node = StartNode()
if start_node.name in self.nodes:
raise ValueError(f'Start node {start_node.name} already exists')
Expand All @@ -27,7 +29,7 @@ def start_at(self, node: Agent | Tool | StartNode | EndNode):
router_fn=partial(default_router, to_node=node.name), to_nodes=[node.name]
)

def add_end_to(self, node: Agent | Tool | StartNode | EndNode):
def add_end_to(self, node: ExecutableNode):
# Create a unique end node name for this specific node
end_node_name = f'__end__{node.name}__'
end_node = EndNode()
Expand Down Expand Up @@ -378,5 +380,9 @@ def _get_node_type(self, node) -> str:
return 'agent'
elif isinstance(node, Tool):
return 'tool'
elif isinstance(node, ForEachNode):
return 'foreach'
elif isinstance(node, AriumNode):
return 'arium'
else:
return 'unknown'
Loading