Skip to content

Commit 8ac6cd4

Browse files
committed
Allow dict input for state
1 parent 49f0f49 commit 8ac6cd4

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

src/strands/agent/agent.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
*,
225225
name: Optional[str] = None,
226226
description: Optional[str] = None,
227-
state: Optional[AgentState] = None,
227+
state: Optional[Union[AgentState, dict]] = None,
228228
):
229229
"""Initialize the Agent with the specified configuration.
230230
@@ -261,7 +261,7 @@ def __init__(
261261
Defaults to None.
262262
description: description of what the Agent does
263263
Defaults to None.
264-
state: stateful information for the agent
264+
state: stateful information for the agent. Can be either an AgentState object, or a json serializable dict.
265265
Defaults to an empty AgentState object.
266266
267267
Raises:
@@ -325,7 +325,17 @@ def __init__(
325325
self.trace_span: Optional[trace.Span] = None
326326

327327
# Initialize agent state management
328-
self.state = state or AgentState()
328+
if state is not None:
329+
if isinstance(state, dict):
330+
self.state = AgentState(state)
331+
elif isinstance(state, AgentState):
332+
print("HERE!")
333+
print(type(state))
334+
self.state = state
335+
else:
336+
raise ValueError("state must be an AgentState object or a dict")
337+
else:
338+
self.state = AgentState()
329339

330340
self.tool_caller = Agent.ToolCaller(self)
331341
self.name = name

tests/strands/agent/test_agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,3 +1314,15 @@ def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_
13141314
kwargs = mock_event_loop_cycle.call_args[1]
13151315
assert "event_loop_parent_span" in kwargs
13161316
assert kwargs["event_loop_parent_span"] == mock_span
1317+
1318+
1319+
def test_non_dict_throws_error():
1320+
with pytest.raises(ValueError, match="state must be an AgentState object or a dict"):
1321+
agent = Agent(state={"object", object()})
1322+
print(agent.state)
1323+
1324+
1325+
def test_non_json_serializable_state_throws_error():
1326+
with pytest.raises(ValueError, match="Value is not JSON serializable"):
1327+
agent = Agent(state={"object": object()})
1328+
print(agent.state)

tests/strands/mocked_model_provider/test_agent_state_updates.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
@tool
99
def update_state(agent: Agent):
1010
agent.state.set("hello", "world")
11+
agent.state.set("foo", "baz")
1112

1213

1314
def test_agent_state_update_from_tool():
@@ -20,10 +21,16 @@ def test_agent_state_update_from_tool():
2021
]
2122
mocked_model_provider = MockedModelProvider(agent_messages)
2223

23-
agent = Agent(model=mocked_model_provider, tools=[update_state])
24+
agent = Agent(
25+
model=mocked_model_provider,
26+
tools=[update_state],
27+
state={"foo": "bar"},
28+
)
2429

2530
assert agent.state.get("hello") is None
31+
assert agent.state.get("foo") == "bar"
2632

2733
agent("Invoke Mocked!")
2834

2935
assert agent.state.get("hello") == "world"
36+
assert agent.state.get("foo") == "baz"

0 commit comments

Comments
 (0)