Skip to content
2 changes: 2 additions & 0 deletions src/agents/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .openai_conversations_session import OpenAIConversationsSession
from .session import Session, SessionABC
from .sqlite_session import SQLiteSession
from .util import SessionInputCallback

__all__ = [
"Session",
"SessionABC",
"SessionInputCallback",
"SQLiteSession",
"OpenAIConversationsSession",
]
20 changes: 20 additions & 0 deletions src/agents/memory/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Callable

from ..items import TResponseInputItem
from ..util._types import MaybeAwaitable

SessionInputCallback = Callable[
[list[TResponseInputItem], list[TResponseInputItem]],
MaybeAwaitable[list[TResponseInputItem]],
]
"""A function that combines session history with new input items.

Args:
history_items: The list of items from the session history.
new_items: The list of new input items for the current turn.

Returns:
A list of combined items to be used as input for the agent. Can be sync or async.
"""
51 changes: 37 additions & 14 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
)
from .lifecycle import RunHooks
from .logger import logger
from .memory import Session
from .memory import Session, SessionInputCallback
from .model_settings import ModelSettings
from .models.interface import Model, ModelProvider
from .models.multi_provider import MultiProvider
Expand Down Expand Up @@ -179,6 +179,13 @@ class RunConfig:
An optional dictionary of additional metadata to include with the trace.
"""

session_input_callback: SessionInputCallback | None = None
"""Defines how to handle session history when new input is provided.
- `None` (default): The new input is appended to the session history.
- `SessionInputCallback`: A custom function that receives the history and new input, and
returns the desired combined list of items.
"""

call_model_input_filter: CallModelInputFilter | None = None
"""
Optional callback that is invoked immediately before calling the model. It receives the current
Expand Down Expand Up @@ -412,7 +419,9 @@ async def run(
run_config = RunConfig()

# Prepare input with session if enabled
prepared_input = await self._prepare_input_with_session(input, session)
prepared_input = await self._prepare_input_with_session(
input, session, run_config.session_input_callback
)

tool_use_tracker = AgentToolUseTracker()

Expand Down Expand Up @@ -780,7 +789,9 @@ async def _start_streaming(

try:
# Prepare input with session if enabled
prepared_input = await AgentRunner._prepare_input_with_session(starting_input, session)
prepared_input = await AgentRunner._prepare_input_with_session(
starting_input, session, run_config.session_input_callback
)

# Update the streamed result with the prepared input
streamed_result.input = prepared_input
Expand Down Expand Up @@ -1473,18 +1484,18 @@ async def _prepare_input_with_session(
cls,
input: str | list[TResponseInputItem],
session: Session | None,
session_input_callback: SessionInputCallback | None,
) -> str | list[TResponseInputItem]:
"""Prepare input by combining it with session history if enabled."""
if session is None:
return input

# Validate that we don't have both a session and a list input, as this creates
# ambiguity about whether the list should append to or replace existing session history
if isinstance(input, list):
# If the user doesn't explicitly specify a mode, raise an error
if isinstance(input, list) and not session_input_callback:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asking users to pass session_input_callback is a breaking change, plus not having the callback should be fine

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn’t seem like a breaking change to me, because it already raises an error if the user provides a list as input.
The difference now is that it should only raise an error if the input is a list and session_input_callback is None. Otherwise, it should use that function to handle the input. Let me know what do you think about it

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes. Thanks for clarifying this. The current error message should be better than the current one

raise UserError(
"Cannot provide both a session and a list of input items. "
"When using session memory, provide only a string input to append to the "
"conversation, or use session=None and provide a list to manually manage "
"You must specify the `session_input_callback` in the `RunConfig`. "
"Otherwise, when using session memory, provide only a string input to append to "
"the conversation, or use session=None and provide a list to manually manage "
"conversation history."
)

Expand All @@ -1494,10 +1505,18 @@ async def _prepare_input_with_session(
# Convert input to list format
new_input_list = ItemHelpers.input_to_new_input_list(input)

# Combine history with new input
combined_input = history + new_input_list

return combined_input
if session_input_callback is None:
return history + new_input_list
elif callable(session_input_callback):
res = session_input_callback(history, new_input_list)
if inspect.isawaitable(res):
return await res
return res
else:
raise UserError(
f"Invalid `session_input_callback` value: {session_input_callback}. "
"Choose between `None` or a custom callable function."
)

@classmethod
async def _save_result_to_session(
Expand All @@ -1506,7 +1525,11 @@ async def _save_result_to_session(
original_input: str | list[TResponseInputItem],
new_items: list[RunItem],
) -> None:
"""Save the conversation turn to session."""
"""
Save the conversation turn to session.
It does not account for any filtering or modification performed by
`RunConfig.session_input_callback`.
"""
if session is None:
return

Expand Down
50 changes: 48 additions & 2 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from agents import Agent, Runner, SQLiteSession, TResponseInputItem
from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem
from agents.exceptions import UserError

from .fake_model import FakeModel
Expand Down Expand Up @@ -394,11 +394,57 @@ async def test_session_memory_rejects_both_session_and_list_input(runner_method)
await run_agent_async(runner_method, agent, list_input, session=session)

# Verify the error message explains the issue
assert "Cannot provide both a session and a list of input items" in str(exc_info.value)
assert "You must specify the `session_input_callback` in" in str(exc_info.value)
assert "manually manage conversation history" in str(exc_info.value)

session.close()


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
@pytest.mark.asyncio
async def test_session_callback_prepared_input(runner_method):
"""Test if the user passes a list of items and want to append them."""
with tempfile.TemporaryDirectory() as temp_dir:
db_path = Path(temp_dir) / "test_memory.db"

model = FakeModel()
agent = Agent(name="test", model=model)

# Session
session_id = "session_1"
session = SQLiteSession(session_id, db_path)

# Add first messages manually
initial_history: list[TResponseInputItem] = [
{"role": "user", "content": "Hello there."},
{"role": "assistant", "content": "Hi, I'm here to assist you."},
]
await session.add_items(initial_history)

def filter_assistant_messages(history, new_input):
# Only include user messages from history
return [item for item in history if item["role"] == "user"] + new_input

new_turn_input = [{"role": "user", "content": "What your name?"}]
model.set_next_output([get_text_message("I'm gpt-4o")])

# Run the agent with the callable
await run_agent_async(
runner_method,
agent,
new_turn_input,
session=session,
run_config=RunConfig(session_input_callback=filter_assistant_messages),
)

expected_model_input = [
initial_history[0], # From history
new_turn_input[0], # New input
]

assert len(model.last_turn_args["input"]) == 2
assert model.last_turn_args["input"] == expected_model_input

@pytest.mark.asyncio
async def test_sqlite_session_unicode_content():
"""Test that session correctly stores and retrieves unicode/non-ASCII content."""
Expand Down