Skip to content

feat(agents): Add on_llm_start and on_llm_end Lifecycle Hooks #987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/agents/lifecycle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Generic
from typing import Any, Generic, Optional

from .agent import Agent
from .items import ModelResponse, TResponseInputItem
from .run_context import RunContextWrapper, TContext
from .tool import Tool

Expand All @@ -10,6 +11,25 @@ class RunHooks(Generic[TContext]):
override the methods you need.
"""

    # Run-level LLM lifecycle hook (paired with on_llm_end below). Useful for
    # logging or monitoring model calls; the default is a no-op so subclasses
    # only override what they need.
    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        """Called just before invoking the LLM for this agent.

        Args:
            context: The run context wrapper for the current run.
            agent: The agent that is about to issue the LLM call.
            system_prompt: The system prompt passed to the model, if any.
            input_items: The input items sent to the model for this call.
        """
        pass

    async def on_llm_end(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], response: ModelResponse
    ) -> None:
        """Called immediately after the LLM call returns for this agent.

        Args:
            context: The run context wrapper for the current run.
            agent: The agent whose LLM call just completed.
            response: The model response returned by the LLM call.
        """
        pass

async def on_agent_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
) -> None:
Expand Down Expand Up @@ -103,3 +123,22 @@ async def on_tool_end(
) -> None:
"""Called after a tool is invoked."""
pass

    # Agent-level counterpart of RunHooks.on_llm_start; default is a no-op.
    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        """Called immediately before the agent issues an LLM call.

        Args:
            context: The run context wrapper for the current run.
            agent: The agent that is about to issue the LLM call.
            system_prompt: The system prompt passed to the model, if any.
            input_items: The input items sent to the model for this call.
        """
        pass

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        """Called immediately after the agent receives the LLM response.

        Args:
            context: The run context wrapper for the current run.
            agent: The agent whose LLM call just completed.
            response: The model response returned by the LLM call.
        """
        pass
6 changes: 6 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,9 @@ async def _get_new_response(
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
# If the agent has hooks, we need to call them before and after the LLM call
if agent.hooks:
await agent.hooks.on_llm_start(context_wrapper, agent, system_prompt, input)

new_response = await model.get_response(
system_instructions=system_prompt,
Expand All @@ -1081,6 +1084,9 @@ async def _get_new_response(
previous_response_id=previous_response_id,
prompt=prompt_config,
)
# If the agent has hooks, we need to call them after the LLM call
if agent.hooks:
await agent.hooks.on_llm_end(context_wrapper, agent, new_response)

context_wrapper.usage.add(new_response.usage)

Expand Down
85 changes: 85 additions & 0 deletions tests/test_agent_llm_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from collections import defaultdict
from typing import Any, Optional

import pytest

from agents.agent import Agent
from agents.items import ModelResponse, TResponseInputItem
from agents.lifecycle import AgentHooks
from agents.run import Runner
from agents.run_context import RunContextWrapper, TContext
from agents.tool import Tool

from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_text_message,
)


class AgentHooksForTests(AgentHooks):
    """Test double that counts every lifecycle callback it receives.

    Each hook increments a named counter in `events`, so tests can assert
    exactly which hooks fired and how many times.
    """

    def __init__(self):
        # defaultdict so a hook can increment its counter without initialization.
        self.events: dict[str, int] = defaultdict(int)

    def reset(self):
        """Clear all recorded event counts."""
        self.events.clear()

    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
        self.events["on_start"] += 1

    async def on_end(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
    ) -> None:
        self.events["on_end"] += 1

    async def on_handoff(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
    ) -> None:
        self.events["on_handoff"] += 1

    async def on_tool_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
    ) -> None:
        self.events["on_tool_start"] += 1

    async def on_tool_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        tool: Tool,
        result: str,
    ) -> None:
        self.events["on_tool_end"] += 1

    # LLM hooks mirroring the new on_llm_start/on_llm_end lifecycle callbacks.
    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        self.events["on_llm_start"] += 1

    async def on_llm_end(
        self,
        # Fixed typo: parameter was misspelled "ccontext"; renamed to match
        # every other hook override in this class.
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        self.events["on_llm_end"] += 1


# Example test using the above hooks:
@pytest.mark.asyncio
async def test_non_streamed_agent_hooks_with_llm():
    """A single non-streamed run fires start/end plus both LLM hooks once each."""
    agent_hooks = AgentHooksForTests()
    fake_model = FakeModel()
    agent = Agent(
        name="A",
        model=fake_model,
        tools=[get_function_tool("f", "res")],
        handoffs=[],
        hooks=agent_hooks,
    )
    # One queued text output means exactly one LLM round trip and no tool calls.
    fake_model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello")

    expected = {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
    assert agent_hooks.events == expected