Skip to content

Commit 26e90c8

Browse files
authored
Observer: Create callbacks for agent, router & tools (#54)
* Setting up callbacks * Adding callbacks to models * Renaming to model_name * Completing callbacks for agent start and ends * Fixing precommit hook issues * Updating poetry lock * Removing line * Fix for properly registering the agent models
1 parent f581f3f commit 26e90c8

19 files changed

+722
-317
lines changed

examples/llm_extensibility.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic import BaseModel, Field
44
from langchain_openai import ChatOpenAI
55
from flo_ai.tools.flo_tool import flotool
6+
from flo_ai.state.flo_callbacks import flo_agent_callback, FloCallbackResponse
67

78
from dotenv import load_dotenv
89
import warnings
@@ -38,7 +39,15 @@ def email_tool(to: str, message: str):
3839
return f'Email sent successfully to: {to}'
3940

4041

42+
@flo_agent_callback
43+
def agent_callback(response: FloCallbackResponse):
44+
print('------------- START AGENT CALLBACK -----------')
45+
print(response)
46+
print('------------- END AGENT CALLBACK -----------')
47+
48+
4149
session.register_tool('SendEmailTool', email_tool)
50+
session.register_callback(agent_callback)
4251

4352
agent_yaml = """
4453
apiVersion: flo/alpha-v1

examples/simple_blogging_team.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
llm = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')
3535
session = (
36-
FloSession(llm, log_level='INFO')
36+
FloSession(llm)
3737
.register_tool(name='TavilySearchResults', tool=TavilySearchResults())
3838
.register_tool(
3939
name='DummyTool',
@@ -43,5 +43,5 @@
4343

4444
Flo.set_log_level('INFO')
4545
flo: Flo = Flo.build(session, yaml=yaml_data)
46-
# data = flo.invoke(input_prompt)
46+
data = flo.invoke(input_prompt)
4747
# print((data['messages'][-1]).content)

flo_ai/common/flo_langchain_logger.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from langchain.callbacks.base import BaseCallbackHandler
33
from langchain.schema import AgentAction, AgentFinish, LLMResult
44
from flo_ai.common.flo_logger import get_logger
5+
from flo_ai.state.flo_callbacks import FloToolCallback
56

67

78
class FloLangchainLogger(BaseCallbackHandler):
8-
def __init__(self, session_id: str):
9+
def __init__(self, session_id: str, tool_callbacks: List[FloToolCallback] = []):
910
self.session_id = session_id
11+
self.tool_callbacks = tool_callbacks
1012

1113
def on_llm_start(
1214
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -41,14 +43,20 @@ def on_tool_start(
4143
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
4244
) -> None:
4345
get_logger().debug(f'onToolStart: {input_str}', self)
46+
[
47+
x.on_tool_start(serialized['name'], kwargs['inputs'], kwargs)
48+
for x in self.tool_callbacks
49+
]
4450

4551
def on_tool_end(self, output: str, **kwargs: Any) -> None:
4652
get_logger().debug(f'onToolEnd: {output}', self)
53+
[x.on_tool_end(kwargs['name'], output, kwargs) for x in self.tool_callbacks]
4754

4855
def on_tool_error(
4956
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
5057
) -> None:
5158
get_logger().debug(f'onToolError: {error}', self)
59+
[x.on_tool_error(kwargs['name'], error, kwargs) for x in self.tool_callbacks]
5260

5361
def on_text(self, text: str, **kwargs: Any) -> None:
5462
get_logger().debug(f'onText: {text}', self)

flo_ai/core.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
set_logger_internal,
1717
FloLogConfig,
1818
)
19+
from langchain.tools import StructuredTool
1920

2021

2122
class Flo:
@@ -90,11 +91,14 @@ def draw_to_file(self, filename: str, xray=True):
9091
def validate_invoke(self, session: FloSession):
9192
async_coroutines = filter(
9293
lambda x: (
93-
hasattr(x, 'coroutine') and asyncio.iscoroutinefunction(x.coroutine)
94+
isinstance(x, StructuredTool)
95+
and hasattr(x, 'coroutine')
96+
and asyncio.iscoroutinefunction(x.coroutine)
9497
),
9598
session.tools.values(),
9699
)
97-
if len(list(async_coroutines)) > 0:
100+
async_tools = list(async_coroutines)
101+
if len(async_tools) > 0:
98102
raise FloException(
99103
f"""You seem to have atleast one async tool registered in this session. Please use flo.async_invoke or flo.async_stream. Checkout {DOCUMENTATION_WEBSITE}"""
100104
)

flo_ai/factory/agent_factory.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,30 @@ def __create_agentic_agent(
6060
agent_model = AgentFactory.__resolve_model(session, agent.model)
6161
tools = [tool_map[tool.name] for tool in agent.tools]
6262
flo_agent: FloAgent = FloAgent.Builder(
63-
session, agent, tools, llm=agent_model, on_error=session.on_agent_error
63+
session,
64+
agent,
65+
tools,
66+
llm=agent_model,
67+
on_error=session.on_agent_error,
68+
model_name=agent.model,
6469
).build()
6570
return flo_agent
6671

6772
@staticmethod
6873
def __create_llm_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent:
6974
agent_model = AgentFactory.__resolve_model(session, agent.model)
70-
builder = FloLLMAgent.Builder(session, agent, llm=agent_model)
75+
builder = FloLLMAgent.Builder(
76+
session, agent, llm=agent_model, model_name=agent.model
77+
)
7178
llm_agent: FloLLMAgent = builder.build()
7279
return llm_agent
7380

7481
@staticmethod
7582
def __create_runnable_agent(session: FloSession, agent: AgentConfig) -> FloLLMAgent:
7683
runnable = session.tools[agent.tools[0].name]
77-
return FloToolAgent.Builder(session, agent, runnable).build()
84+
return FloToolAgent.Builder(
85+
session, agent, runnable, model_name=agent.model
86+
).build()
7887

7988
@staticmethod
8089
def __create_reflection_agent(

flo_ai/models/flo_agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,14 @@
1313

1414
class FloAgent(ExecutableFlo):
1515
def __init__(
16-
self, agent: Runnable, executor: AgentExecutor, config: AgentConfig
16+
self,
17+
agent: Runnable,
18+
executor: AgentExecutor,
19+
config: AgentConfig,
20+
model_nick_name: str,
1721
) -> None:
1822
super().__init__(config.name, executor, ExecutableType.agentic)
23+
self.model_name = model_nick_name
1924
self.agent: Runnable = (agent,)
2025
self.executor: AgentExecutor = executor
2126
self.config: AgentConfig = config
@@ -30,9 +35,11 @@ def __init__(
3035
role: Optional[str] = None,
3136
llm: Union[BaseLanguageModel, None] = None,
3237
on_error: Union[str, Callable] = True,
38+
model_name: Union[str, None] = 'default',
3339
) -> None:
3440
prompt: Union[ChatPromptTemplate, str] = config.job
3541
self.name: str = config.name
42+
self.model_name = model_name
3643
self.llm = llm if llm is not None else session.llm
3744
self.config = config
3845
system_prompts = (
@@ -60,4 +67,6 @@ def build(self) -> AgentExecutor:
6067
return_intermediate_steps=True,
6168
handle_parsing_errors=self.on_error,
6269
)
63-
return FloAgent(agent, executor, self.config)
70+
return FloAgent(
71+
agent, executor, self.config, model_nick_name=self.model_name
72+
)

flo_ai/models/flo_delegation_agent.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,29 @@
1010

1111

1212
class FloDelegatorAgent(ExecutableFlo):
13-
def __init__(self, executor: Runnable, config: AgentConfig) -> None:
13+
def __init__(
14+
self, executor: Runnable, config: AgentConfig, model_name: str
15+
) -> None:
1416
super().__init__(config.name, executor, ExecutableType.delegator)
1517
self.executor: Runnable = executor
1618
self.config: AgentConfig = config
19+
self.model_name = model_name
1720

1821
class Builder:
1922
def __init__(
2023
self,
2124
session: FloSession,
2225
agentConfig: AgentConfig,
2326
llm: Optional[BaseLanguageModel] = None,
27+
model_name: str = None,
2428
) -> None:
2529
self.config = agentConfig
2630
delegator_base_system_message = (
2731
'You are a delegator tasked with routing a conversation between the'
2832
' following {member_type}: {members}. Given the following rules,'
2933
' respond with the worker to act next '
3034
)
35+
self.model_name = model_name
3136
self.llm = session.llm if llm is None else llm
3237
self.options = [x.name for x in agentConfig.to]
3338
self.llm_router_prompt = ChatPromptTemplate.from_messages(
@@ -75,4 +80,6 @@ def build(self):
7580
| JsonOutputFunctionsParser()
7681
)
7782

78-
return FloDelegatorAgent(executor=chain, config=self.config)
83+
return FloDelegatorAgent(
84+
executor=chain, config=self.config, model_name=self.model_name
85+
)

flo_ai/models/flo_llm_agent.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,23 @@
1010

1111

1212
class FloLLMAgent(ExecutableFlo):
13-
def __init__(self, executor: Runnable, config: AgentConfig) -> None:
13+
def __init__(
14+
self, executor: Runnable, config: AgentConfig, model_name: str
15+
) -> None:
1416
super().__init__(config.name, executor, ExecutableType.llm)
1517
self.executor: Runnable = executor
1618
self.config: AgentConfig = config
19+
self.model_name: str = model_name
1720

1821
class Builder:
1922
def __init__(
2023
self,
2124
session: FloSession,
2225
config: AgentConfig,
2326
llm: Union[BaseLanguageModel, None] = None,
27+
model_name: str = None,
2428
) -> None:
29+
self.model_name = model_name
2530
prompt: Union[ChatPromptTemplate, str] = config.job
2631

2732
self.name: str = config.name
@@ -42,4 +47,4 @@ def __init__(
4247

4348
def build(self) -> Runnable:
4449
executor = self.prompt | self.llm | StrOutputParser()
45-
return FloLLMAgent(executor, self.config)
50+
return FloLLMAgent(executor, self.config, self.model_name)

flo_ai/models/flo_node.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from langchain_core.messages import HumanMessage
77
from flo_ai.yaml.config import AgentConfig, TeamConfig
88
from flo_ai.models.flo_executable import ExecutableType
9-
from typing import Union
9+
from flo_ai.state.flo_session import FloSession
10+
from typing import Union, Type, List
11+
from flo_ai.state.flo_callbacks import FloAgentCallback, FloRouterCallback, FloCallback
1012

1113

1214
class FloNode:
@@ -23,19 +25,26 @@ def __init__(
2325
self.config: Union[AgentConfig | TeamConfig] = config
2426

2527
class Builder:
28+
def __init__(self, session: FloSession) -> None:
29+
self.session = session
30+
2631
def build_from_agent(self, flo_agent: FloAgent) -> 'FloNode':
2732
agent_func = functools.partial(
2833
FloNode.Builder.__teamflo_agent_node,
2934
agent=flo_agent.runnable,
3035
name=flo_agent.name,
3136
agent_config=flo_agent.config,
37+
session=self.session,
38+
model_name=flo_agent.model_name,
3239
)
3340
return FloNode(agent_func, flo_agent.name, flo_agent.type, flo_agent.config)
3441

3542
def build_from_team(self, flo_team: FloRoutedTeam) -> 'FloNode':
3643
team_chain = (
3744
functools.partial(
38-
FloNode.Builder.__teamflo_team_node, members=flo_team.runnable.nodes
45+
FloNode.Builder.__teamflo_team_node,
46+
members=flo_team.runnable.nodes,
47+
session=self.session,
3948
)
4049
| flo_team.runnable
4150
)
@@ -56,6 +65,8 @@ def build_from_router(self, flo_router) -> 'FloNode':
5665
agent=flo_router.executor,
5766
name=flo_router.router_name,
5867
agent_config=flo_router.config,
68+
session=self.session,
69+
model_name=flo_router.model_name,
5970
)
6071
return FloNode(
6172
router_func, flo_router.router_name, flo_router.type, flo_router.config
@@ -67,20 +78,95 @@ def __teamflo_agent_node(
6778
agent: AgentExecutor,
6879
name: str,
6980
agent_config: AgentConfig,
81+
session: FloSession,
82+
model_name: str,
7083
):
71-
result = agent.invoke(state)
72-
output = result if isinstance(result, str) else result['output']
84+
agent_cbs: List[FloAgentCallback] = FloNode.Builder.__filter_callbacks(
85+
session, FloAgentCallback
86+
)
87+
flo_cbs: List[FloCallback] = FloNode.Builder.__filter_callbacks(
88+
session, FloCallback
89+
)
90+
[
91+
callback.on_agent_start(name, model_name, state['messages'], **{})
92+
for callback in agent_cbs
93+
]
94+
[
95+
callback.on_agent_start(name, model_name, state['messages'], **{})
96+
for callback in flo_cbs
97+
]
98+
try:
99+
result = agent.invoke(state)
100+
output = result if isinstance(result, str) else result['output']
101+
except Exception as e:
102+
[
103+
callback.on_agent_error(name, model_name, e, **{})
104+
for callback in agent_cbs
105+
]
106+
[
107+
callback.on_agent_error(name, model_name, e, **{})
108+
for callback in flo_cbs
109+
]
110+
raise e
111+
[
112+
callback.on_agent_end(name, model_name, output, **{})
113+
for callback in agent_cbs
114+
]
115+
[
116+
callback.on_agent_start(name, model_name, output, **{})
117+
for callback in flo_cbs
118+
]
73119
return {STATE_NAME_MESSAGES: [HumanMessage(content=output, name=name)]}
74120

121+
@staticmethod
122+
def __filter_callbacks(session: FloSession, type: Type):
123+
cbs = session.callbacks
124+
return list(filter(lambda callback: isinstance(callback, type), cbs))
125+
75126
@staticmethod
76127
def __teamflo_router_node(
77128
state: TeamFloAgentState,
78129
agent: AgentExecutor,
79130
name: str,
80131
agent_config: AgentConfig,
132+
session: FloSession,
133+
model_name: str,
81134
):
82-
result = agent.invoke(state)
83-
nextNode = result if isinstance(result, str) else result['next']
135+
agent_cbs: List[FloRouterCallback] = FloNode.Builder.__filter_callbacks(
136+
session, FloRouterCallback
137+
)
138+
flo_cbs: List[FloCallback] = FloNode.Builder.__filter_callbacks(
139+
session, FloCallback
140+
)
141+
[
142+
callback.on_router_start(name, model_name, state['messages'], **{})
143+
for callback in agent_cbs
144+
]
145+
[
146+
callback.on_router_start(name, model_name, state['messages'], **{})
147+
for callback in flo_cbs
148+
]
149+
try:
150+
result = agent.invoke(state)
151+
nextNode = result if isinstance(result, str) else result['next']
152+
except Exception as e:
153+
[
154+
callback.on_router_error(name, model_name, e, **{})
155+
for callback in agent_cbs
156+
]
157+
[
158+
callback.on_router_error(name, model_name, e, **{})
159+
for callback in flo_cbs
160+
]
161+
raise e
162+
[
163+
callback.on_router_end(name, model_name, nextNode, **{})
164+
for callback in agent_cbs
165+
]
166+
[
167+
callback.on_router_start(name, model_name, nextNode, **{})
168+
for callback in flo_cbs
169+
]
84170
return {'next': nextNode}
85171

86172
@staticmethod

0 commit comments

Comments
 (0)