Skip to content

Commit bd36b95

Browse files
Unshurezastrowm
andauthored
feat: Agent State (#292)
* feat: Add Agent State * Update state.py * Allow dict input for state * Update src/strands/agent/agent.py Co-authored-by: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> * fix: deepcopy AgentState * Update test_agent.py with comments --------- Co-authored-by: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
1 parent 66b4aef commit bd36b95

File tree

7 files changed

+389
-0
lines changed

7 files changed

+389
-0
lines changed

src/strands/agent/agent.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ConversationManager,
3939
SlidingWindowConversationManager,
4040
)
41+
from .state import AgentState
4142

4243
logger = logging.getLogger(__name__)
4344

@@ -193,6 +194,7 @@ def __init__(
193194
*,
194195
name: Optional[str] = None,
195196
description: Optional[str] = None,
197+
state: Optional[Union[AgentState, dict]] = None,
196198
):
197199
"""Initialize the Agent with the specified configuration.
198200
@@ -229,6 +231,8 @@ def __init__(
229231
Defaults to None.
230232
description: description of what the Agent does
231233
Defaults to None.
234+
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
235+
Defaults to an empty AgentState object.
232236
233237
Raises:
234238
ValueError: If max_parallel_tools is less than 1.
@@ -289,6 +293,18 @@ def __init__(
289293
# Initialize tracer instance (no-op if not configured)
290294
self.tracer = get_tracer()
291295
self.trace_span: Optional[trace.Span] = None
296+
297+
# Initialize agent state management
298+
if state is not None:
299+
if isinstance(state, dict):
300+
self.state = AgentState(state)
301+
elif isinstance(state, AgentState):
302+
self.state = state
303+
else:
304+
raise ValueError("state must be an AgentState object or a dict")
305+
else:
306+
self.state = AgentState()
307+
292308
self.tool_caller = Agent.ToolCaller(self)
293309
self.name = name
294310
self.description = description

src/strands/agent/state.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Agent state management."""
2+
3+
import copy
4+
import json
5+
from typing import Any, Dict, Optional
6+
7+
8+
class AgentState:
9+
"""Represents an Agent's stateful information outside of context provided to a model.
10+
11+
Provides a key-value store for agent state with JSON serialization validation and persistence support.
12+
Key features:
13+
- JSON serialization validation on assignment
14+
- Get/set/delete operations
15+
"""
16+
17+
def __init__(self, initial_state: Optional[Dict[str, Any]] = None):
18+
"""Initialize AgentState."""
19+
self._state: Dict[str, Dict[str, Any]]
20+
if initial_state:
21+
self._validate_json_serializable(initial_state)
22+
self._state = copy.deepcopy(initial_state)
23+
else:
24+
self._state = {}
25+
26+
def set(self, key: str, value: Any) -> None:
27+
"""Set a value in the state.
28+
29+
Args:
30+
key: The key to store the value under
31+
value: The value to store (must be JSON serializable)
32+
33+
Raises:
34+
ValueError: If key is invalid, or if value is not JSON serializable
35+
"""
36+
self._validate_key(key)
37+
self._validate_json_serializable(value)
38+
39+
self._state[key] = copy.deepcopy(value)
40+
41+
def get(self, key: Optional[str] = None) -> Any:
42+
"""Get a value or entire state.
43+
44+
Args:
45+
key: The key to retrieve (if None, returns entire state object)
46+
47+
Returns:
48+
The stored value, entire state dict, or None if not found
49+
"""
50+
if key is None:
51+
return copy.deepcopy(self._state)
52+
else:
53+
# Return specific key
54+
return copy.deepcopy(self._state.get(key))
55+
56+
def delete(self, key: str) -> None:
57+
"""Delete a specific key from the state.
58+
59+
Args:
60+
key: The key to delete
61+
"""
62+
self._validate_key(key)
63+
64+
self._state.pop(key, None)
65+
66+
def _validate_key(self, key: str) -> None:
67+
"""Validate that a key is valid.
68+
69+
Args:
70+
key: The key to validate
71+
72+
Raises:
73+
ValueError: If key is invalid
74+
"""
75+
if key is None:
76+
raise ValueError("Key cannot be None")
77+
if not isinstance(key, str):
78+
raise ValueError("Key must be a string")
79+
if not key.strip():
80+
raise ValueError("Key cannot be empty")
81+
82+
def _validate_json_serializable(self, value: Any) -> None:
83+
"""Validate that a value is JSON serializable.
84+
85+
Args:
86+
value: The value to validate
87+
88+
Raises:
89+
ValueError: If value is not JSON serializable
90+
"""
91+
try:
92+
json.dumps(value)
93+
except (TypeError, ValueError) as e:
94+
raise ValueError(
95+
f"Value is not JSON serializable: {type(value).__name__}. "
96+
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed."
97+
) from e

tests/strands/agent/test_agent.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import importlib
3+
import json
34
import os
45
import textwrap
56
import unittest.mock
@@ -1203,3 +1204,58 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_
12031204
kwargs = mock_event_loop_cycle.call_args[1]
12041205
assert "event_loop_parent_span" in kwargs
12051206
assert kwargs["event_loop_parent_span"] == mock_span
1207+
1208+
1209+
def test_non_dict_throws_error():
1210+
with pytest.raises(ValueError, match="state must be an AgentState object or a dict"):
1211+
agent = Agent(state={"object", object()})
1212+
print(agent.state)
1213+
1214+
1215+
def test_non_json_serializable_state_throws_error():
1216+
with pytest.raises(ValueError, match="Value is not JSON serializable"):
1217+
agent = Agent(state={"object": object()})
1218+
print(agent.state)
1219+
1220+
1221+
def test_agent_state_breaks_dict_reference():
1222+
ref_dict = {"hello": "world"}
1223+
agent = Agent(state=ref_dict)
1224+
1225+
# Make sure shallow object references do not affect state maintained by AgentState
1226+
ref_dict["hello"] = object()
1227+
1228+
# This will fail if AgentState reflects the updated reference
1229+
json.dumps(agent.state.get())
1230+
1231+
1232+
def test_agent_state_breaks_deep_dict_reference():
1233+
ref_dict = {"world": "!"}
1234+
init_dict = {"hello": ref_dict}
1235+
agent = Agent(state=init_dict)
1236+
# Make sure deep reference changes do not affect state mained by AgentState
1237+
ref_dict["world"] = object()
1238+
1239+
# This will fail if AgentState reflects the updated reference
1240+
json.dumps(agent.state.get())
1241+
1242+
1243+
def test_agent_state_set_breaks_dict_reference():
1244+
agent = Agent()
1245+
ref_dict = {"hello": "world"}
1246+
# Set should copy the input, and not maintain the reference to the original object
1247+
agent.state.set("hello", ref_dict)
1248+
ref_dict["hello"] = object()
1249+
1250+
# This will fail if AgentState reflects the updated reference
1251+
json.dumps(agent.state.get())
1252+
1253+
1254+
def test_agent_state_get_breaks_deep_dict_reference():
1255+
agent = Agent(state={"hello": {"world": "!"}})
1256+
# Get should not return a reference to the internal state
1257+
ref_state = agent.state.get()
1258+
ref_state["hello"]["world"] = object()
1259+
1260+
# This will fail if AgentState reflects the updated reference
1261+
json.dumps(agent.state.get())
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Tests for AgentState class."""
2+
3+
import pytest
4+
5+
from strands.agent.state import AgentState
6+
7+
8+
def test_set_and_get():
9+
"""Test basic set and get operations."""
10+
state = AgentState()
11+
state.set("key", "value")
12+
assert state.get("key") == "value"
13+
14+
15+
def test_get_nonexistent_key():
16+
"""Test getting nonexistent key returns None."""
17+
state = AgentState()
18+
assert state.get("nonexistent") is None
19+
20+
21+
def test_get_entire_state():
22+
"""Test getting entire state when no key specified."""
23+
state = AgentState()
24+
state.set("key1", "value1")
25+
state.set("key2", "value2")
26+
27+
result = state.get()
28+
assert result == {"key1": "value1", "key2": "value2"}
29+
30+
31+
def test_initialize_and_get_entire_state():
32+
"""Test getting entire state when no key specified."""
33+
state = AgentState({"key1": "value1", "key2": "value2"})
34+
35+
result = state.get()
36+
assert result == {"key1": "value1", "key2": "value2"}
37+
38+
39+
def test_initialize_with_error():
40+
with pytest.raises(ValueError, match="not JSON serializable"):
41+
AgentState({"object", object()})
42+
43+
44+
def test_delete():
45+
"""Test deleting keys."""
46+
state = AgentState()
47+
state.set("key1", "value1")
48+
state.set("key2", "value2")
49+
50+
state.delete("key1")
51+
52+
assert state.get("key1") is None
53+
assert state.get("key2") == "value2"
54+
55+
56+
def test_delete_nonexistent_key():
57+
"""Test deleting nonexistent key doesn't raise error."""
58+
state = AgentState()
59+
state.delete("nonexistent") # Should not raise
60+
61+
62+
def test_json_serializable_values():
63+
"""Test that only JSON-serializable values are accepted."""
64+
state = AgentState()
65+
66+
# Valid JSON types
67+
state.set("string", "test")
68+
state.set("int", 42)
69+
state.set("bool", True)
70+
state.set("list", [1, 2, 3])
71+
state.set("dict", {"nested": "value"})
72+
state.set("null", None)
73+
74+
# Invalid JSON types should raise ValueError
75+
with pytest.raises(ValueError, match="not JSON serializable"):
76+
state.set("function", lambda x: x)
77+
78+
with pytest.raises(ValueError, match="not JSON serializable"):
79+
state.set("object", object())
80+
81+
82+
def test_key_validation():
83+
"""Test key validation for set and delete operations."""
84+
state = AgentState()
85+
86+
# Invalid keys for set
87+
with pytest.raises(ValueError, match="Key cannot be None"):
88+
state.set(None, "value")
89+
90+
with pytest.raises(ValueError, match="Key cannot be empty"):
91+
state.set("", "value")
92+
93+
with pytest.raises(ValueError, match="Key must be a string"):
94+
state.set(123, "value")
95+
96+
# Invalid keys for delete
97+
with pytest.raises(ValueError, match="Key cannot be None"):
98+
state.delete(None)
99+
100+
with pytest.raises(ValueError, match="Key cannot be empty"):
101+
state.delete("")
102+
103+
104+
def test_initial_state():
105+
"""Test initialization with initial state."""
106+
initial = {"key1": "value1", "key2": "value2"}
107+
state = AgentState(initial_state=initial)
108+
109+
assert state.get("key1") == "value1"
110+
assert state.get("key2") == "value2"
111+
assert state.get() == initial

tests/strands/mocked_model_provider/__init__.py

Whitespace-only changes.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import json
2+
from typing import Any, Callable, Iterable, Optional, Type, TypeVar
3+
4+
from pydantic import BaseModel
5+
6+
from strands.types.content import Message, Messages
7+
from strands.types.event_loop import StopReason
8+
from strands.types.models.model import Model
9+
from strands.types.streaming import StreamEvent
10+
from strands.types.tools import ToolSpec
11+
12+
T = TypeVar("T", bound=BaseModel)
13+
14+
15+
class MockedModelProvider(Model):
16+
"""A mock implementation of the Model interface for testing purposes.
17+
18+
This class simulates a model provider by returning pre-defined agent responses
19+
in sequence. It implements the Model interface methods and provides functionality
20+
to stream mock responses as events.
21+
"""
22+
23+
def __init__(self, agent_responses: Messages):
24+
self.agent_responses = agent_responses
25+
self.index = 0
26+
27+
def format_chunk(self, event: Any) -> StreamEvent:
28+
return event
29+
30+
def format_request(
31+
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
32+
) -> Any:
33+
return None
34+
35+
def get_config(self) -> Any:
36+
pass
37+
38+
def update_config(self, **model_config: Any) -> None:
39+
pass
40+
41+
def structured_output(
42+
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None
43+
) -> T:
44+
pass
45+
46+
def stream(self, request: Any) -> Iterable[Any]:
47+
yield from self.map_agent_message_to_events(self.agent_responses[self.index])
48+
self.index += 1
49+
50+
def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]:
51+
stop_reason: StopReason = "end_turn"
52+
yield {"messageStart": {"role": "assistant"}}
53+
for content in agent_message["content"]:
54+
if "text" in content:
55+
yield {"contentBlockStart": {"start": {}}}
56+
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}}
57+
yield {"contentBlockStop": {}}
58+
if "toolUse" in content:
59+
stop_reason = "tool_use"
60+
yield {
61+
"contentBlockStart": {
62+
"start": {
63+
"toolUse": {
64+
"name": content["toolUse"]["name"],
65+
"toolUseId": content["toolUse"]["toolUseId"],
66+
}
67+
}
68+
}
69+
}
70+
yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}}
71+
yield {"contentBlockStop": {}}
72+
73+
yield {"messageStop": {"stopReason": stop_reason}}

0 commit comments

Comments
 (0)