87 changes: 83 additions & 4 deletions agent-memory-client/agent_memory_client/models.py
@@ -5,13 +5,17 @@
For full model definitions, see the main agent_memory_server package.
"""

-from datetime import datetime, timezone
+import logging
+import threading
+from datetime import datetime, timedelta, timezone
from enum import Enum
-from typing import Any, Literal
+from typing import Any, ClassVar, Literal

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
from ulid import ULID

+logger = logging.getLogger(__name__)

# Model name literals for model-specific window sizes
ModelNameLiteral = Literal[
"gpt-3.5-turbo",
@@ -62,6 +66,15 @@ class MemoryStrategyConfig(BaseModel):
class MemoryMessage(BaseModel):
    """A message in the memory system"""

    # Track message IDs that have been warned (in-memory, per-process)
    # Used to rate-limit deprecation warnings
    _warned_message_ids: ClassVar[set[str]] = set()
    _warned_message_ids_lock: ClassVar[threading.Lock] = threading.Lock()
    _max_warned_ids: ClassVar[int] = 10000  # Prevent unbounded growth

    # Default tolerance for future timestamp validation (5 minutes)
    _max_future_seconds: ClassVar[int] = 300

    role: str
    content: str
    id: str = Field(
@@ -70,7 +83,7 @@
    )
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
-        description="Timestamp when the message was created",
+        description="Timestamp when the message was created (should be provided by client)",
    )
    persisted_at: datetime | None = Field(
        default=None,
@@ -81,6 +94,72 @@
description="Whether memory extraction has run for this message",
)

    @model_validator(mode="before")
    @classmethod
    def validate_created_at(cls, data: Any) -> Any:
        """
        Validate the created_at timestamp:
        - Warn if not provided by the client (will become required in a future version)
        - Error if the timestamp is in the future (beyond tolerance)
        """
        if not isinstance(data, dict):
            return data

        created_at_provided = "created_at" in data and data["created_at"] is not None

        if not created_at_provided:
            # Rate-limit warnings by message ID (thread-safe)
            msg_id = data.get("id", "unknown")

            with cls._warned_message_ids_lock:
                if msg_id not in cls._warned_message_ids:
                    # Prevent unbounded memory growth
                    if len(cls._warned_message_ids) >= cls._max_warned_ids:
                        cls._warned_message_ids.clear()
                    cls._warned_message_ids.add(msg_id)
                    should_warn = True
                else:
                    should_warn = False

            if should_warn:
                logger.warning(
                    "MemoryMessage created without explicit created_at timestamp. "
                    "This will become required in a future version. "
                    "Please provide created_at for accurate message ordering."
                )
        else:
            # Validate that created_at is not in the future
            created_at_value = data["created_at"]

            # Parse string to datetime if needed
            if isinstance(created_at_value, str):
                try:
                    # Handle ISO format with Z suffix
                    created_at_value = datetime.fromisoformat(
                        created_at_value.replace("Z", "+00:00")
                    )
                except ValueError:
                    # Let Pydantic handle the parsing error
                    return data

            if isinstance(created_at_value, datetime):
                # Ensure timezone-aware comparison
                now = datetime.now(timezone.utc)
                if created_at_value.tzinfo is None:
                    # Assume UTC for naive datetimes
                    created_at_value = created_at_value.replace(tzinfo=timezone.utc)

                max_allowed = now + timedelta(seconds=cls._max_future_seconds)

                if created_at_value > max_allowed:
                    raise ValueError(
                        f"created_at cannot be more than {cls._max_future_seconds} seconds in the future. "
                        f"Received: {created_at_value.isoformat()}, "
                        f"Max allowed: {max_allowed.isoformat()}"
                    )

        return data


class MemoryRecord(BaseModel):
"""A memory record"""
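How the new validator behaves from the caller's side — a minimal sketch, not part of the diff, assuming the client package is importable as `agent_memory_client` (the message IDs below are illustrative, not valid ULIDs):

```python
from datetime import datetime, timedelta, timezone

from agent_memory_client.models import MemoryMessage

# No created_at: validation succeeds, but a deprecation warning is logged
# (at most once per message ID, via the _warned_message_ids registry).
MemoryMessage(id="msg-1", role="user", content="hello")

# Explicit timestamp within the 300-second tolerance: accepted as-is.
MemoryMessage(
    id="msg-2",
    role="user",
    content="hello again",
    created_at=datetime.now(timezone.utc),
)

# More than _max_future_seconds ahead: rejected. Pydantic v2 surfaces the
# validator's ValueError as a ValidationError, itself a ValueError subclass.
try:
    MemoryMessage(
        id="msg-3",
        role="user",
        content="from the future",
        created_at=datetime.now(timezone.utc) + timedelta(hours=1),
    )
except ValueError as exc:
    print(exc)  # created_at cannot be more than 300 seconds in the future...
```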
88 changes: 66 additions & 22 deletions agent_memory_server/api.py
@@ -2,7 +2,7 @@
from typing import Any

import tiktoken
-from fastapi import APIRouter, Depends, Header, HTTPException, Query
+from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response
from mcp.server.fastmcp.prompts import base
from mcp.types import TextContent

@@ -452,39 +452,28 @@ async def get_working_memory(
    return WorkingMemoryResponse(**working_mem_data)


-@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse)
-async def put_working_memory(
+async def put_working_memory_core(
    session_id: str,
    memory: UpdateWorkingMemory,
    background_tasks: HybridBackgroundTasks,
    model_name: ModelNameLiteral | None = None,
    context_window_max: int | None = None,
-    current_user: UserInfo = Depends(get_current_user),
-):
+) -> WorkingMemoryResponse:
"""
Set working memory for a session. Replaces existing working memory.

The session_id comes from the URL path, not the request body.
If the token count exceeds the context window threshold, messages will be summarized
immediately and the updated memory state returned to the client.
Core implementation of put_working_memory.

NOTE on context_percentage_* fields:
The response includes `context_percentage_total_used` and `context_percentage_until_summarization`
fields that show token usage. These fields will be `null` unless you provide either:
- `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`)
- `context_window_max` query parameter (e.g., `?context_window_max=500`)
This function contains the business logic for setting working memory and can be
called from both the REST API endpoint and MCP tools.

Args:
session_id: The session ID (from URL path)
memory: Working memory data to save (session_id not required in body)
session_id: The session ID
memory: Working memory data to save
background_tasks: Background tasks handler
model_name: The client's LLM model name for context window determination
context_window_max: Direct specification of context window max tokens (overrides model_name)
background_tasks: DocketBackgroundTasks instance (injected automatically)
context_window_max: Direct specification of context window max tokens

Returns:
Updated working memory (potentially with summary if tokens were condensed).
Includes context_percentage_total_used and context_percentage_until_summarization
if model information is provided.
Updated working memory response
"""
    redis = await get_redis_conn()

@@ -557,6 +546,61 @@ async def put_working_memory(
    return WorkingMemoryResponse(**updated_memory_data)


@router.put("/v1/working-memory/{session_id}", response_model=WorkingMemoryResponse)
async def put_working_memory(
    session_id: str,
    memory: UpdateWorkingMemory,
    background_tasks: HybridBackgroundTasks,
    response: Response,
    model_name: ModelNameLiteral | None = None,
    context_window_max: int | None = None,
    current_user: UserInfo = Depends(get_current_user),
):
    """
    Set working memory for a session. Replaces existing working memory.

    The session_id comes from the URL path, not the request body.
    If the token count exceeds the context window threshold, messages will be summarized
    immediately and the updated memory state returned to the client.

    NOTE on context_percentage_* fields:
    The response includes `context_percentage_total_used` and `context_percentage_until_summarization`
    fields that show token usage. These fields will be `null` unless you provide either:
    - `model_name` query parameter (e.g., `?model_name=gpt-4o-mini`)
    - `context_window_max` query parameter (e.g., `?context_window_max=500`)

    Args:
        session_id: The session ID (from URL path)
        memory: Working memory data to save (session_id not required in body)
        model_name: The client's LLM model name for context window determination
        context_window_max: Direct specification of context window max tokens (overrides model_name)
        background_tasks: DocketBackgroundTasks instance (injected automatically)
        response: FastAPI Response object for setting headers

    Returns:
        Updated working memory (potentially with summary if tokens were condensed).
        Includes context_percentage_total_used and context_percentage_until_summarization
        if model information is provided.
    """
    # Check if any messages are missing created_at timestamps and add deprecation header
    messages_missing_timestamp = any(
        not getattr(msg, "_created_at_was_provided", True) for msg in memory.messages
    )
    if messages_missing_timestamp:
        response.headers["X-Deprecation-Warning"] = (
            "messages[].created_at will become required in the next major version. "
            "Please provide timestamps for all messages."
        )

    return await put_working_memory_core(
        session_id=session_id,
        memory=memory,
        background_tasks=background_tasks,
        model_name=model_name,
        context_window_max=context_window_max,
    )


@router.delete("/v1/working-memory/{session_id}", response_model=AckResponse)
async def delete_working_memory(
    session_id: str,
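To see the new header from a client's perspective — a hedged sketch, not from the PR: it assumes a server listening on localhost:8000 with auth disabled, `httpx` installed, and that `UpdateWorkingMemory` accepts a bare `messages` list (the exact body schema is not shown in this diff):

```python
import httpx

# PUT a message without created_at: the request still succeeds, but the
# endpoint is expected to attach the X-Deprecation-Warning header.
resp = httpx.put(
    "http://localhost:8000/v1/working-memory/session-123",
    json={"messages": [{"role": "user", "content": "hi"}]},
)
resp.raise_for_status()

warning = resp.headers.get("X-Deprecation-Warning")
if warning:
    print(warning)  # "messages[].created_at will become required..."
```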
7 changes: 7 additions & 0 deletions agent_memory_server/config.py
@@ -314,6 +314,13 @@ class Settings(BaseSettings):
        0.7  # Fraction of context window that triggers summarization
    )

    # Message timestamp validation settings
    # If true, reject messages without a created_at timestamp.
    # If false (default), auto-generate the timestamp with a deprecation warning.
    require_message_timestamps: bool = False
    # Maximum allowed clock skew for future timestamp validation (in seconds)
    max_future_timestamp_seconds: int = 300  # 5 minutes

    # Working memory migration settings
    # Set to True to skip backward compatibility checks for old string-format keys.
    # Use this after running 'agent-memory migrate-working-memory' or for fresh installs.
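Deployments that want the strict behavior can opt in through the new settings — a sketch assuming the remaining `Settings` fields have defaults or come from the environment, as is usual for pydantic `BaseSettings`:

```python
from agent_memory_server.config import Settings

# Explicit keyword overrides work regardless of env-var naming conventions.
strict = Settings(
    require_message_timestamps=True,   # reject messages lacking created_at
    max_future_timestamp_seconds=600,  # tolerate up to 10 minutes of clock skew
)

assert strict.require_message_timestamps
assert strict.max_future_timestamp_seconds == 600
```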
2 changes: 1 addition & 1 deletion agent_memory_server/mcp.py
@@ -11,7 +11,7 @@
    get_long_term_memory as core_get_long_term_memory,
    get_working_memory as core_get_working_memory,
    memory_prompt as core_memory_prompt,
-    put_working_memory as core_put_working_memory,
+    put_working_memory_core as core_put_working_memory,
    search_long_term_memory as core_search_long_term_memory,
    update_long_term_memory as core_update_long_term_memory,
)
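The rename matters because the MCP layer bypasses FastAPI's dependency injection: tools now call the core function, which takes plain arguments and needs no `Response` object or `Depends(get_current_user)`. A rough sketch of the call shape — `handle_set_working_memory` and its parameters are hypothetical, not from the PR:

```python
from agent_memory_server.api import put_working_memory_core

async def handle_set_working_memory(session_id, memory, background_tasks):
    # Hypothetical MCP-side wrapper: forwards straight to the business logic;
    # model_name/context_window_max stay None unless the tool supplies them.
    return await put_working_memory_core(
        session_id=session_id,
        memory=memory,
        background_tasks=background_tasks,
        model_name=None,
        context_window_max=None,
    )
```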