Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions examples/memory/advanced_sqlite_session_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ async def main():
# Show current conversation
print("Current conversation:")
current_items = await session.get_items()
for i, item in enumerate(current_items, 1):
for i, item in enumerate(current_items, 1): # type: ignore[assignment]
role = str(item.get("role", item.get("type", "unknown")))
if item.get("type") == "function_call":
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
Expand All @@ -151,8 +151,8 @@ async def main():
# Show available turns for branching
print("\nAvailable turns for branching:")
turns = await session.get_conversation_turns()
for turn in turns:
print(f" Turn {turn['turn']}: {turn['content']}")
for turn in turns: # type: ignore[assignment]
print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index]

# Create a branch from turn 2
print("\nCreating new branch from turn 2...")
Expand All @@ -163,7 +163,7 @@ async def main():
branch_items = await session.get_items()
print(f"Items copied to new branch: {len(branch_items)}")
print("New branch contains:")
for i, item in enumerate(branch_items, 1):
for i, item in enumerate(branch_items, 1): # type: ignore[assignment]
role = str(item.get("role", item.get("type", "unknown")))
if item.get("type") == "function_call":
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
Expand Down Expand Up @@ -198,7 +198,7 @@ async def main():
print("\n=== New Conversation Branch ===")
new_conversation = await session.get_items()
print("New conversation with branch:")
for i, item in enumerate(new_conversation, 1):
for i, item in enumerate(new_conversation, 1): # type: ignore[assignment]
role = str(item.get("role", item.get("type", "unknown")))
if item.get("type") == "function_call":
content = f"{item.get('name', 'unknown')}({item.get('arguments', '{}')})"
Expand All @@ -224,8 +224,8 @@ async def main():
# Show conversation turns in current branch
print("\nConversation turns in current branch:")
current_turns = await session.get_conversation_turns()
for turn in current_turns:
print(f" Turn {turn['turn']}: {turn['content']}")
for turn in current_turns: # type: ignore[assignment]
print(f" Turn {turn['turn']}: {turn['content']}") # type: ignore[index]

print("\n=== Branch Switching Demo ===")
print("We can switch back to the main branch...")
Expand Down
4 changes: 2 additions & 2 deletions examples/memory/dapr_session_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ async def demonstrate_multi_store():
r_items = await redis_session.get_items()
p_items = await pg_session.get_items()

r_example = r_items[-1]["content"] if r_items else "empty"
p_example = p_items[-1]["content"] if p_items else "empty"
r_example = r_items[-1]["content"] if r_items else "empty" # type: ignore[typeddict-item]
p_example = p_items[-1]["content"] if p_items else "empty" # type: ignore[typeddict-item]

print(f"{redis_store}: {len(r_items)} items; example: {r_example}")
print(f"{pg_store}: {len(p_items)} items; example: {p_example}")
Expand Down
13 changes: 12 additions & 1 deletion src/agents/extensions/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .advanced_sqlite_session import AdvancedSQLiteSession
from .dapr_session import (
DAPR_CONSISTENCY_EVENTUAL,
DAPR_CONSISTENCY_STRONG,
DaprSession,
)
from .encrypt_session import EncryptedSession
from .redis_session import RedisSession
from .sqlalchemy_session import SQLAlchemySession

__all__: list[str] = [
"AdvancedSQLiteSession",
Expand Down
34 changes: 17 additions & 17 deletions tests/extensions/memory/test_dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def _create_test_session(
session = DaprSession(
session_id=session_id,
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)

# Clean up any existing data
Expand Down Expand Up @@ -260,12 +260,12 @@ async def test_session_isolation(fake_dapr_client: FakeDaprClient):
session1 = DaprSession(
session_id="session_1",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)
session2 = DaprSession(
session_id="session_2",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)

try:
Expand Down Expand Up @@ -386,7 +386,7 @@ async def test_pop_from_empty_session(fake_dapr_client: FakeDaprClient):
session = DaprSession(
session_id="empty_session",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)
try:
await session.clear_session()
Expand Down Expand Up @@ -540,7 +540,7 @@ async def test_dapr_connectivity(fake_dapr_client: FakeDaprClient):
session = DaprSession(
session_id="connectivity_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)
try:
# Test ping
Expand All @@ -555,7 +555,7 @@ async def test_ttl_functionality(fake_dapr_client: FakeDaprClient):
session = DaprSession(
session_id="ttl_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
ttl=3600, # 1 hour TTL
)

Expand Down Expand Up @@ -586,15 +586,15 @@ async def test_consistency_levels(fake_dapr_client: FakeDaprClient):
session_eventual = DaprSession(
session_id="eventual_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
consistency=DAPR_CONSISTENCY_EVENTUAL,
)

# Test strong consistency
session_strong = DaprSession(
session_id="strong_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
consistency=DAPR_CONSISTENCY_STRONG,
)

Expand All @@ -621,7 +621,7 @@ async def test_external_client_not_closed(fake_dapr_client: FakeDaprClient):
session = DaprSession(
session_id="external_client_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)

try:
Expand Down Expand Up @@ -650,7 +650,7 @@ async def test_internal_client_ownership(fake_dapr_client: FakeDaprClient):
session = DaprSession(
session_id="internal_client_test",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)
session._owns_client = True # Simulate ownership

Expand Down Expand Up @@ -732,7 +732,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient):
session1 = DaprSession(
session_id="close_test_1",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)

# Verify _owns_client is False for external client
Expand All @@ -749,7 +749,7 @@ async def test_close_method_coverage(fake_dapr_client: FakeDaprClient):
session2 = DaprSession(
session_id="close_test_2",
state_store_name="statestore",
dapr_client=fake_dapr_client2,
dapr_client=fake_dapr_client2, # type: ignore[arg-type]
)
session2._owns_client = True # Simulate ownership

Expand Down Expand Up @@ -788,8 +788,8 @@ async def test_already_deserialized_messages(fake_dapr_client: FakeDaprClient):
# Should handle both string and dict messages
items = await session.get_items()
assert len(items) == 2
assert items[0]["content"] == "First message"
assert items[1]["content"] == "Second message"
assert items[0]["content"] == "First message" # type: ignore[typeddict-item]
assert items[1]["content"] == "Second message" # type: ignore[typeddict-item]

await session.close()

Expand All @@ -800,7 +800,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
async with DaprSession(
"test_cm_session",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
) as session:
# Verify we got the session object back
assert session.session_id == "test_cm_session"
Expand All @@ -809,7 +809,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
await session.add_items([{"role": "user", "content": "Test message"}])
items = await session.get_items()
assert len(items) == 1
assert items[0]["content"] == "Test message"
assert items[0]["content"] == "Test message" # type: ignore[typeddict-item]

# After exiting context manager, close should have been called
# Verify we can still check the state (fake client doesn't truly disconnect)
Expand All @@ -819,7 +819,7 @@ async def test_context_manager(fake_dapr_client: FakeDaprClient):
owned_session = DaprSession(
"test_cm_owned",
state_store_name="statestore",
dapr_client=fake_dapr_client,
dapr_client=fake_dapr_client, # type: ignore[arg-type]
)
# Manually set ownership to simulate from_address behavior
owned_session._owns_client = True
Expand Down