Skip to content

Commit 56f53f3

Browse files
committed
fix: deepcopy AgentState
1 parent 9862adb commit 56f53f3

File tree

2 files changed

+41
-5
lines changed

2 files changed

+41
-5
lines changed

src/strands/agent/state.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Agent state management."""
22

3+
import copy
34
import json
45
from typing import Any, Dict, Optional
56

@@ -13,12 +14,12 @@ class AgentState:
1314
- Get/set/delete operations
1415
"""
1516

16-
def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None):
17+
def __init__(self, initial_state: Optional[Dict[str, Any]] = None):
1718
"""Initialize AgentState."""
1819
self._state: Dict[str, Dict[str, Any]]
1920
if initial_state:
2021
self._validate_json_serializable(initial_state)
21-
self._state = initial_state.copy()
22+
self._state = copy.deepcopy(initial_state)
2223
else:
2324
self._state = {}
2425

@@ -35,7 +36,7 @@ def set(self, key: str, value: Any) -> None:
3536
self._validate_key(key)
3637
self._validate_json_serializable(value)
3738

38-
self._state[key] = value
39+
self._state[key] = copy.deepcopy(value)
3940

4041
def get(self, key: Optional[str] = None) -> Any:
4142
"""Get a value or entire state.
@@ -47,10 +48,10 @@ def get(self, key: Optional[str] = None) -> Any:
4748
The stored value, entire state dict, or None if not found
4849
"""
4950
if key is None:
50-
return self._state.copy()
51+
return copy.deepcopy(self._state)
5152
else:
5253
# Return specific key
53-
return self._state.get(key)
54+
return copy.deepcopy(self._state.get(key))
5455

5556
def delete(self, key: str) -> None:
5657
"""Delete a specific key from the state.

tests/strands/agent/test_agent.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import importlib
3+
import json
34
import os
45
import textwrap
56
import threading
@@ -1326,3 +1327,37 @@ def test_non_json_serializable_state_throws_error():
13261327
with pytest.raises(ValueError, match="Value is not JSON serializable"):
13271328
agent = Agent(state={"object": object()})
13281329
print(agent.state)
1330+
1331+
1332+
def test_agent_state_breaks_dict_reference():
1333+
ref_dict = {"hello": "world"}
1334+
agent = Agent(state=ref_dict)
1335+
ref_dict["hello"] = object()
1336+
1337+
json.dumps(agent.state.get())
1338+
1339+
1340+
def test_agent_state_breaks_deep_dict_reference():
1341+
ref_dict = {"world": "!"}
1342+
init_dict = {"hello": ref_dict}
1343+
agent = Agent(state=init_dict)
1344+
ref_dict["world"] = object()
1345+
1346+
json.dumps(agent.state.get())
1347+
1348+
1349+
def test_agent_state_set_breaks_dict_reference():
1350+
agent = Agent()
1351+
ref_dict = {"hello": "world"}
1352+
agent.state.set("hello", ref_dict)
1353+
ref_dict["hello"] = object()
1354+
1355+
json.dumps(agent.state.get())
1356+
1357+
1358+
def test_agent_state_get_breaks_deep_dict_reference():
1359+
agent = Agent(state={"hello": {"world": "!"}})
1360+
ref_state = agent.state.get()
1361+
ref_state["hello"]["world"] = object()
1362+
1363+
json.dumps(agent.state.get())

0 commit comments

Comments
 (0)