-
Notifications
You must be signed in to change notification settings - Fork 228
feat: Agent State #292
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat: Agent State #292
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
d0f6276
feat: Add Agent State
Unshure 49f0f49
Update state.py
Unshure 8ac6cd4
Allow dict input for state
Unshure 9862adb
Update src/strands/agent/agent.py
Unshure 56f53f3
fix: deepcopy AgentState
Unshure 4bfddbb
Update test_agent.py with comments
Unshure File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Agent state management.""" | ||
|
||
import json | ||
from typing import Any, Dict, Optional | ||
|
||
|
||
class AgentState: | ||
"""Represents an Agent's stateful information outside of context provided to a model. | ||
|
||
Provides a key-value store for agent state with JSON serialization validation and persistence support. | ||
Key features: | ||
- JSON serialization validation on assignment | ||
- Get/set/delete operations | ||
""" | ||
|
||
def __init__(self, initial_state: Optional[Dict[str, Dict[str, Any]]] = None): | ||
"""Initialize AgentState.""" | ||
self._state: Dict[str, Dict[str, Any]] | ||
if initial_state: | ||
self._validate_json_serializable(initial_state) | ||
self._state = initial_state.copy() | ||
Unshure marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
self._state = {} | ||
|
||
def set(self, key: str, value: Any) -> None: | ||
"""Set a value in the state. | ||
|
||
Args: | ||
key: The key to store the value under | ||
value: The value to store (must be JSON serializable) | ||
|
||
Raises: | ||
ValueError: If key is invalid, or if value is not JSON serializable | ||
""" | ||
self._validate_key(key) | ||
self._validate_json_serializable(value) | ||
|
||
self._state[key] = value | ||
|
||
def get(self, key: Optional[str] = None) -> Any: | ||
"""Get a value or entire state. | ||
|
||
Args: | ||
key: The key to retrieve (if None, returns entire state object) | ||
|
||
Returns: | ||
The stored value, entire state dict, or None if not found | ||
""" | ||
if key is None: | ||
return self._state.copy() | ||
Unshure marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
# Return specific key | ||
return self._state.get(key) | ||
|
||
def delete(self, key: str) -> None: | ||
"""Delete a specific key from the state. | ||
|
||
Args: | ||
key: The key to delete | ||
""" | ||
self._validate_key(key) | ||
|
||
self._state.pop(key, None) | ||
|
||
def _validate_key(self, key: str) -> None: | ||
"""Validate that a key is valid. | ||
|
||
Args: | ||
key: The key to validate | ||
|
||
Raises: | ||
ValueError: If key is invalid | ||
""" | ||
if key is None: | ||
raise ValueError("Key cannot be None") | ||
if not isinstance(key, str): | ||
raise ValueError("Key must be a string") | ||
if not key.strip(): | ||
raise ValueError("Key cannot be empty") | ||
|
||
def _validate_json_serializable(self, value: Any) -> None: | ||
"""Validate that a value is JSON serializable. | ||
|
||
Args: | ||
value: The value to validate | ||
|
||
Raises: | ||
ValueError: If value is not JSON serializable | ||
""" | ||
try: | ||
json.dumps(value) | ||
except (TypeError, ValueError) as e: | ||
raise ValueError( | ||
f"Value is not JSON serializable: {type(value).__name__}. " | ||
f"Only JSON-compatible types (str, int, float, bool, list, dict, None) are allowed." | ||
) from e |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
"""Tests for AgentState class.""" | ||
|
||
import pytest | ||
|
||
from strands.agent.state import AgentState | ||
|
||
|
||
def test_set_and_get(): | ||
"""Test basic set and get operations.""" | ||
state = AgentState() | ||
state.set("key", "value") | ||
assert state.get("key") == "value" | ||
|
||
|
||
def test_get_nonexistent_key(): | ||
"""Test getting nonexistent key returns None.""" | ||
state = AgentState() | ||
assert state.get("nonexistent") is None | ||
|
||
|
||
def test_get_entire_state(): | ||
"""Test getting entire state when no key specified.""" | ||
state = AgentState() | ||
state.set("key1", "value1") | ||
state.set("key2", "value2") | ||
|
||
result = state.get() | ||
assert result == {"key1": "value1", "key2": "value2"} | ||
|
||
|
||
def test_initialize_and_get_entire_state(): | ||
"""Test getting entire state when no key specified.""" | ||
state = AgentState({"key1": "value1", "key2": "value2"}) | ||
|
||
result = state.get() | ||
assert result == {"key1": "value1", "key2": "value2"} | ||
|
||
|
||
def test_initialize_with_error(): | ||
with pytest.raises(ValueError, match="not JSON serializable"): | ||
AgentState({"object", object()}) | ||
|
||
|
||
def test_delete(): | ||
"""Test deleting keys.""" | ||
state = AgentState() | ||
state.set("key1", "value1") | ||
state.set("key2", "value2") | ||
|
||
state.delete("key1") | ||
|
||
assert state.get("key1") is None | ||
assert state.get("key2") == "value2" | ||
|
||
|
||
def test_delete_nonexistent_key(): | ||
"""Test deleting nonexistent key doesn't raise error.""" | ||
state = AgentState() | ||
state.delete("nonexistent") # Should not raise | ||
|
||
|
||
def test_json_serializable_values(): | ||
"""Test that only JSON-serializable values are accepted.""" | ||
state = AgentState() | ||
|
||
# Valid JSON types | ||
state.set("string", "test") | ||
state.set("int", 42) | ||
state.set("bool", True) | ||
state.set("list", [1, 2, 3]) | ||
state.set("dict", {"nested": "value"}) | ||
state.set("null", None) | ||
|
||
# Invalid JSON types should raise ValueError | ||
with pytest.raises(ValueError, match="not JSON serializable"): | ||
state.set("function", lambda x: x) | ||
|
||
with pytest.raises(ValueError, match="not JSON serializable"): | ||
state.set("object", object()) | ||
|
||
|
||
def test_key_validation(): | ||
"""Test key validation for set and delete operations.""" | ||
state = AgentState() | ||
|
||
# Invalid keys for set | ||
with pytest.raises(ValueError, match="Key cannot be None"): | ||
state.set(None, "value") | ||
|
||
with pytest.raises(ValueError, match="Key cannot be empty"): | ||
state.set("", "value") | ||
|
||
with pytest.raises(ValueError, match="Key must be a string"): | ||
state.set(123, "value") | ||
|
||
# Invalid keys for delete | ||
with pytest.raises(ValueError, match="Key cannot be None"): | ||
state.delete(None) | ||
|
||
with pytest.raises(ValueError, match="Key cannot be empty"): | ||
state.delete("") | ||
|
||
|
||
def test_initial_state(): | ||
"""Test initialization with initial state.""" | ||
initial = {"key1": "value1", "key2": "value2"} | ||
state = AgentState(initial_state=initial) | ||
|
||
assert state.get("key1") == "value1" | ||
assert state.get("key2") == "value2" | ||
assert state.get() == initial |
Empty file.
73 changes: 73 additions & 0 deletions
73
tests/strands/mocked_model_provider/mocked_model_provider.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import json | ||
from typing import Any, Callable, Iterable, Optional, Type, TypeVar | ||
|
||
from pydantic import BaseModel | ||
|
||
from strands.types.content import Message, Messages | ||
from strands.types.event_loop import StopReason | ||
from strands.types.models.model import Model | ||
from strands.types.streaming import StreamEvent | ||
from strands.types.tools import ToolSpec | ||
|
||
T = TypeVar("T", bound=BaseModel) | ||
|
||
|
||
class MockedModelProvider(Model): | ||
"""A mock implementation of the Model interface for testing purposes. | ||
|
||
This class simulates a model provider by returning pre-defined agent responses | ||
in sequence. It implements the Model interface methods and provides functionality | ||
to stream mock responses as events. | ||
""" | ||
|
||
def __init__(self, agent_responses: Messages): | ||
self.agent_responses = agent_responses | ||
self.index = 0 | ||
|
||
def format_chunk(self, event: Any) -> StreamEvent: | ||
return event | ||
|
||
def format_request( | ||
self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None | ||
) -> Any: | ||
return None | ||
|
||
def get_config(self) -> Any: | ||
pass | ||
|
||
def update_config(self, **model_config: Any) -> None: | ||
pass | ||
|
||
def structured_output( | ||
self, output_model: Type[T], prompt: Messages, callback_handler: Optional[Callable] = None | ||
) -> T: | ||
pass | ||
|
||
def stream(self, request: Any) -> Iterable[Any]: | ||
yield from self.map_agent_message_to_events(self.agent_responses[self.index]) | ||
self.index += 1 | ||
|
||
def map_agent_message_to_events(self, agent_message: Message) -> Iterable[dict[str, Any]]: | ||
stop_reason: StopReason = "end_turn" | ||
yield {"messageStart": {"role": "assistant"}} | ||
for content in agent_message["content"]: | ||
if "text" in content: | ||
yield {"contentBlockStart": {"start": {}}} | ||
yield {"contentBlockDelta": {"delta": {"text": content["text"]}}} | ||
yield {"contentBlockStop": {}} | ||
if "toolUse" in content: | ||
stop_reason = "tool_use" | ||
yield { | ||
"contentBlockStart": { | ||
"start": { | ||
"toolUse": { | ||
"name": content["toolUse"]["name"], | ||
"toolUseId": content["toolUse"]["toolUseId"], | ||
} | ||
} | ||
} | ||
} | ||
yield {"contentBlockDelta": {"delta": {"tool_use": {"input": json.dumps(content["toolUse"]["input"])}}}} | ||
yield {"contentBlockStop": {}} | ||
|
||
yield {"messageStop": {"stopReason": stop_reason}} |
36 changes: 36 additions & 0 deletions
36
tests/strands/mocked_model_provider/test_agent_state_updates.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from strands.agent.agent import Agent | ||
from strands.tools.decorator import tool | ||
from strands.types.content import Messages | ||
|
||
from .mocked_model_provider import MockedModelProvider | ||
|
||
|
||
@tool | ||
def update_state(agent: Agent): | ||
agent.state.set("hello", "world") | ||
agent.state.set("foo", "baz") | ||
|
||
|
||
def test_agent_state_update_from_tool(): | ||
agent_messages: Messages = [ | ||
{ | ||
"role": "assistant", | ||
"content": [{"toolUse": {"name": "update_state", "toolUseId": "123", "input": {}}}], | ||
}, | ||
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]}, | ||
] | ||
mocked_model_provider = MockedModelProvider(agent_messages) | ||
|
||
agent = Agent( | ||
model=mocked_model_provider, | ||
tools=[update_state], | ||
state={"foo": "bar"}, | ||
) | ||
|
||
assert agent.state.get("hello") is None | ||
assert agent.state.get("foo") == "bar" | ||
|
||
agent("Invoke Mocked!") | ||
|
||
assert agent.state.get("hello") == "world" | ||
assert agent.state.get("foo") == "baz" |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.