66from langchain_core .messages import HumanMessage
77from flo_ai .yaml .config import AgentConfig , TeamConfig
88from flo_ai .models .flo_executable import ExecutableType
9- from typing import Union
9+ from flo_ai .state .flo_session import FloSession
10+ from typing import Union , Type , List
11+ from flo_ai .state .flo_callbacks import FloAgentCallback , FloRouterCallback , FloCallback
1012
1113
1214class FloNode :
@@ -23,19 +25,26 @@ def __init__(
2325 self .config : Union [AgentConfig | TeamConfig ] = config
2426
2527 class Builder :
28+ def __init__ (self , session : FloSession ) -> None :
29+ self .session = session
30+
2631 def build_from_agent (self , flo_agent : FloAgent ) -> 'FloNode' :
2732 agent_func = functools .partial (
2833 FloNode .Builder .__teamflo_agent_node ,
2934 agent = flo_agent .runnable ,
3035 name = flo_agent .name ,
3136 agent_config = flo_agent .config ,
37+ session = self .session ,
38+ model_name = flo_agent .model_name ,
3239 )
3340 return FloNode (agent_func , flo_agent .name , flo_agent .type , flo_agent .config )
3441
3542 def build_from_team (self , flo_team : FloRoutedTeam ) -> 'FloNode' :
3643 team_chain = (
3744 functools .partial (
38- FloNode .Builder .__teamflo_team_node , members = flo_team .runnable .nodes
45+ FloNode .Builder .__teamflo_team_node ,
46+ members = flo_team .runnable .nodes ,
47+ session = self .session ,
3948 )
4049 | flo_team .runnable
4150 )
@@ -56,6 +65,8 @@ def build_from_router(self, flo_router) -> 'FloNode':
5665 agent = flo_router .executor ,
5766 name = flo_router .router_name ,
5867 agent_config = flo_router .config ,
68+ session = self .session ,
69+ model_name = flo_router .model_name ,
5970 )
6071 return FloNode (
6172 router_func , flo_router .router_name , flo_router .type , flo_router .config
@@ -67,20 +78,95 @@ def __teamflo_agent_node(
6778 agent : AgentExecutor ,
6879 name : str ,
6980 agent_config : AgentConfig ,
81+ session : FloSession ,
82+ model_name : str ,
7083 ):
71- result = agent .invoke (state )
72- output = result if isinstance (result , str ) else result ['output' ]
84+ agent_cbs : List [FloAgentCallback ] = FloNode .Builder .__filter_callbacks (
85+ session , FloAgentCallback
86+ )
87+ flo_cbs : List [FloCallback ] = FloNode .Builder .__filter_callbacks (
88+ session , FloCallback
89+ )
90+ [
91+ callback .on_agent_start (name , model_name , state ['messages' ], ** {})
92+ for callback in agent_cbs
93+ ]
94+ [
95+ callback .on_agent_start (name , model_name , state ['messages' ], ** {})
96+ for callback in flo_cbs
97+ ]
98+ try :
99+ result = agent .invoke (state )
100+ output = result if isinstance (result , str ) else result ['output' ]
101+ except Exception as e :
102+ [
103+ callback .on_agent_error (name , model_name , e , ** {})
104+ for callback in agent_cbs
105+ ]
106+ [
107+ callback .on_agent_error (name , model_name , e , ** {})
108+ for callback in flo_cbs
109+ ]
110+ raise e
111+ [
112+ callback .on_agent_end (name , model_name , output , ** {})
113+ for callback in agent_cbs
114+ ]
115+ [
116+ callback .on_agent_start (name , model_name , output , ** {})
117+ for callback in flo_cbs
118+ ]
73119 return {STATE_NAME_MESSAGES : [HumanMessage (content = output , name = name )]}
74120
121+ @staticmethod
122+ def __filter_callbacks (session : FloSession , type : Type ):
123+ cbs = session .callbacks
124+ return list (filter (lambda callback : isinstance (callback , type ), cbs ))
125+
75126 @staticmethod
76127 def __teamflo_router_node (
77128 state : TeamFloAgentState ,
78129 agent : AgentExecutor ,
79130 name : str ,
80131 agent_config : AgentConfig ,
132+ session : FloSession ,
133+ model_name : str ,
81134 ):
82- result = agent .invoke (state )
83- nextNode = result if isinstance (result , str ) else result ['next' ]
135+ agent_cbs : List [FloRouterCallback ] = FloNode .Builder .__filter_callbacks (
136+ session , FloRouterCallback
137+ )
138+ flo_cbs : List [FloCallback ] = FloNode .Builder .__filter_callbacks (
139+ session , FloCallback
140+ )
141+ [
142+ callback .on_router_start (name , model_name , state ['messages' ], ** {})
143+ for callback in agent_cbs
144+ ]
145+ [
146+ callback .on_router_start (name , model_name , state ['messages' ], ** {})
147+ for callback in flo_cbs
148+ ]
149+ try :
150+ result = agent .invoke (state )
151+ nextNode = result if isinstance (result , str ) else result ['next' ]
152+ except Exception as e :
153+ [
154+ callback .on_router_error (name , model_name , e , ** {})
155+ for callback in agent_cbs
156+ ]
157+ [
158+ callback .on_router_error (name , model_name , e , ** {})
159+ for callback in flo_cbs
160+ ]
161+ raise e
162+ [
163+ callback .on_router_end (name , model_name , nextNode , ** {})
164+ for callback in agent_cbs
165+ ]
166+ [
167+ callback .on_router_start (name , model_name , nextNode , ** {})
168+ for callback in flo_cbs
169+ ]
84170 return {'next' : nextNode }
85171
86172 @staticmethod
0 commit comments