Skip to content

feat: Agent State #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ConversationManager,
SlidingWindowConversationManager,
)
from .state import AgentState

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -223,6 +224,7 @@ def __init__(
*,
name: Optional[str] = None,
description: Optional[str] = None,
state: Optional[Union[AgentState, dict]] = None,
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -259,6 +261,8 @@ def __init__(
Defaults to None.
description: description of what the Agent does
Defaults to None.
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
Defaults to an empty AgentState object.

Raises:
ValueError: If max_parallel_tools is less than 1.
Expand Down Expand Up @@ -319,6 +323,20 @@ def __init__(
# Initialize tracer instance (no-op if not configured)
self.tracer = get_tracer()
self.trace_span: Optional[trace.Span] = None

# Initialize agent state management
if state is not None:
if isinstance(state, dict):
self.state = AgentState(state)
elif isinstance(state, AgentState):
print("HERE!")
print(type(state))
self.state = state
else:
raise ValueError("state must be an AgentState object or a dict")
else:
self.state = AgentState()

self.tool_caller = Agent.ToolCaller(self)
self.name = name
self.description = description
Expand Down
96 changes: 96 additions & 0 deletions src/strands/agent/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Agent state management."""

import json
from typing import Any, Dict, Optional


class AgentState:
"""Represents an Agent's stateful information outside of context provided to a model.

Provides a key-value store for agent state with JSON serialization validation and persistence support.
Key features:
- JSON serialization validation on assignment
- Get/set/delete operations
"""

def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None):
"""Initialize AgentState."""
self._state: Dict[str, Dict[str, Any]]
if initial_state:
self._validate_json_serializable(initial_state)
self._state = initial_state.copy()
else:
self._state = {}

def set(self, key: str, value: Any) -> None:
"""Set a value in the state.

Args:
key: The key to store the value under
value: The value to store (must be JSON serializable)

Raises:
ValueError: If key is invalid, or if value is not JSON serializable
"""
self._validate_key(key)
self._validate_json_serializable(value)

self._state[key] = value

def get(self, key: Optional[str] = None) -> Any:
"""Get a value or entire state.

Args:
key: The key to retrieve (if None, returns entire state object)

Returns:
The stored value, entire state dict, or None if not found
"""
if key is None:
return self._state.copy()
else:
# Return specific key
return self._state.get(key)

def delete(self, key: str) -> None:
"""Delete a specific key from the state.

Args:
key: The key to delete
"""
self._validate_key(key)

self._state.pop(key, None)

def _validate_key(self, key: str) -> None:
"""Validate that a key is valid.

Args:
key: The key to validate

Raises:
ValueError: If key is invalid
"""
if key is None:
raise ValueError("Key cannot be None")
if not isinstance(key, str):
raise ValueError("Key must be a string")
if not key.strip():
raise ValueError("Key cannot be empty")

def _validate_json_serializable(self, value: Any) -> None:
"""Validate that a value is JSON serializable.

Args:
value: The value to validate

Raises:
ValueError: If value is not JSON serializable
"""
try:
json.dumps(value)
except (TypeError, ValueError) as e:
raise ValueError(
f"Value is not JSON serializable: {type(value).__name__}. "
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
) from e
12 changes: 12 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,3 +1314,15 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_
kwargs = mock_event_loop_cycle.call_args[1]
assert "event_loop_parent_span" in kwargs
assert kwargs["event_loop_parent_span"] == mock_span


def test_non_dict_throws_error():
with pytest.raises(ValueError, match="state must be an AgentState object or a dict"):
agent = Agent(state={"object", object()})
print(agent.state)


def test_non_json_serializable_state_throws_error():
with pytest.raises(ValueError, match="Value is not JSON serializable"):
agent = Agent(state={"object": object()})
print(agent.state)
111 changes: 111 additions & 0 deletions tests/strands/agent/test_agent_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Tests for AgentState class."""

import pytest

from strands.agent.state import AgentState


def test_set_and_get():
"""Test basic set and get operations."""
state = AgentState()
state.set("key", "value")
assert state.get("key") == "value"


def test_get_nonexistent_key():
"""Test getting nonexistent key returns None."""
state = AgentState()
assert state.get("nonexistent") is None


def test_get_entire_state():
"""Test getting entire state when no key specified."""
state = AgentState()
state.set("key1", "value1")
state.set("key2", "value2")

result = state.get()
assert result == {"key1": "value1", "key2": "value2"}


def test_initialize_and_get_entire_state():
"""Test getting entire state when no key specified."""
state = AgentState({"key1": "value1", "key2": "value2"})

result = state.get()
assert result == {"key1": "value1", "key2": "value2"}


def test_initialize_with_error():
with pytest.raises(ValueError, match="not JSON serializable"):
AgentState({"object", object()})


def test_delete():
"""Test deleting keys."""
state = AgentState()
state.set("key1", "value1")
state.set("key2", "value2")

state.delete("key1")

assert state.get("key1") is None
assert state.get("key2") == "value2"


def test_delete_nonexistent_key():
"""Test deleting nonexistent key doesn't raise error."""
state = AgentState()
state.delete("nonexistent") # Should not raise


def test_json_serializable_values():
"""Test that only JSON-serializable values are accepted."""
state = AgentState()

# Valid JSON types
state.set("string", "test")
state.set("int", 42)
state.set("bool", True)
state.set("list", [1, 2, 3])
state.set("dict", {"nested": "value"})
state.set("null", None)

# Invalid JSON types should raise ValueError
with pytest.raises(ValueError, match="not JSON serializable"):
state.set("function", lambda x: x)

with pytest.raises(ValueError, match="not JSON serializable"):
state.set("object", object())


def test_key_validation():
"""Test key validation for set and delete operations."""
state = AgentState()

# Invalid keys for set
with pytest.raises(ValueError, match="Key cannot be None"):
state.set(None, "value")

with pytest.raises(ValueError, match="Key cannot be empty"):
state.set("", "value")

with pytest.raises(ValueError, match="Key must be a string"):
state.set(123, "value")

# Invalid keys for delete
with pytest.raises(ValueError, match="Key cannot be None"):
state.delete(None)

with pytest.raises(ValueError, match="Key cannot be empty"):
state.delete("")


def test_initial_state():
"""Test initialization with initial state."""
initial = {"key1": "value1", "key2": "value2"}
state = AgentState(initial_state=initial)

assert state.get("key1") == "value1"
assert state.get("key2") == "value2"
assert state.get() == initial
Empty file.
73 changes: 73 additions & 0 deletions tests/strands/mocked_model_provider/mocked_model_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import json
from typing import Any, Callable, Iterable, Optional, Type, TypeVar

from pydantic import BaseModel

from strands.types.content import Message, Messages
from strands.types.event_loop import StopReason
from strands.types.models.model import Model
from strands.types.streaming import StreamEvent
from strands.types.tools import ToolSpec

T = TypeVar("T", bound=BaseModel)


class MockedModelProvider(Model):
"""A mock implementation of the Model interface for testing purposes.

This class simulates a model provider by returning pre-defined agent responses
in sequence. It implements the Model interface methods and provides functionality
to stream mock responses as events.
"""

def __init__(self, agent_responses: Messages):
self.agent_responses = agent_responses
self.index = 0

def format_chunk(self, event: Any) -> StreamEvent:
return event

def format_request(
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
) -> Any:
return None

def get_config(self) -> Any:
pass

def update_config(self, **model_config: Any) -> None:
pass

def structured_output(
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
) -> T:
pass

def stream(self, request: Any) -> Iterable[Any]:
yield from self.map_agent_message_to_events(self.agent_responses[self.index])
self.index += 1

def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]:
stop_reason: StopReason = "end_turn"
yield {"messageStart": {"role": "assistant"}}
for content in agent_message["content"]:
if "text" in content:
yield {"contentBlockStart": {"start": {}}}
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
yield {"contentBlockStop": {}}
if "toolUse" in content:
stop_reason = "tool_use"
yield {
"contentBlockStart": {
"start": {
"toolUse": {
"name": content["toolUse"]["name"],
"toolUseId": content["toolUse"]["toolUseId"],
}
}
}
}
yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}}
yield {"contentBlockStop": {}}

yield {"messageStop": {"stopReason": stop_reason}}
36 changes: 36 additions & 0 deletions tests/strands/mocked_model_provider/test_agent_state_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from strands.agent.agent import Agent
from strands.tools.decorator import tool
from strands.types.content import Messages

from .mocked_model_provider import MockedModelProvider


@tool
def update_state(agent: Agent):
agent.state.set("hello", "world")
agent.state.set("foo", "baz")


def test_agent_state_update_from_tool():
agent_messages: Messages = [
{
"role": "assistant",
"content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}],
},
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
]
mocked_model_provider = MockedModelProvider(agent_messages)

agent = Agent(
model=mocked_model_provider,
tools=[update_state],
state={"foo": "bar"},
)

assert agent.state.get("hello") is None
assert agent.state.get("foo") == "bar"

agent("Invoke Mocked!")

assert agent.state.get("hello") == "world"
assert agent.state.get("foo") == "baz"