Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions memory/cascading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Dict, Optional, Callable
import re
from typing import List, Dict, Optional, Callable, Tuple
import numpy as np
from .base import BaseMemory
from .decay import get_decay_fn, decay_ebbinghaus
Expand All @@ -9,16 +10,60 @@ def _extractive_summary(messages: List[Dict], max_chars: int = 400) -> str:
"""
Lightweight extractive summary: keep sentences that contain key=value patterns.
No LLM call needed — fast and cost-free.

Update messages ("changed to") are placed first so they survive truncation
and take precedence over initial injection lines for the same fact.
"""
lines = []
update_lines = []
injection_lines = []
for m in messages:
content = m.get("content", "")
if any(kw in content.lower() for kw in ["my ", "is ", "are ", "changed to", "name", "city", "age"]):
lines.append(f"{m['role']}: {content}")
if "changed to" in content.lower():
update_lines.append(f"{m['role']}: {content}")
elif any(kw in content.lower() for kw in ["my ", "is ", "are ", "name", "city", "age"]):
injection_lines.append(f"{m['role']}: {content}")

lines = update_lines + injection_lines
summary = " | ".join(lines)
return summary[:max_chars] if summary else "No key facts."


def _parse_update(content: str) -> Optional[Tuple[str, str]]:
"""
Parse "Actually, my <key_name> has changed to <new_value>."
Returns (key_name, new_value) or None.
"""
m = re.search(
r"my\s+(.+?)\s+has\s+changed\s+to\s+(.+?)[\.\!]?\s*$",
content,
re.IGNORECASE,
)
if m:
return m.group(1).strip(), m.group(2).strip()
return None


def _patch_cold_with_update(cold: List[str], key_name: str, new_value: str) -> List[str]:
"""
Replace any value associated with *key_name* in cold summary strings so that
stale original values are overwritten by the current value.

Targets the pattern produced by Fact.injection_text():
"my <key> is <old_value>" → "my <key> is <new_value>"
and also direct "changed to <old_new>" occurrences from earlier updates.
"""
# Match "my <key_name> is <anything up to a pipe/period/end>"
pattern = re.compile(
r"(my\s+" + re.escape(key_name) + r"\s+(?:is|are|was|has been)\s+)([^|.\n]+)",
re.IGNORECASE,
)
result = []
for entry in cold:
patched = pattern.sub(lambda m: m.group(1) + new_value, entry)
result.append(patched)
return result


class CascadingTemporalMemory(BaseMemory):
"""
Three-tier cascading memory with pluggable temporal decay.
Expand All @@ -30,6 +75,11 @@ class CascadingTemporalMemory(BaseMemory):

Decay options: 'ebbinghaus' (default) | 'exponential' | 'linear' | 'default'
Reference: Ebbinghaus, H. (1885). Über das Gedächtnis.

Fix (issue #2): when a fact-update message cascades from the warm tier into
cold, existing cold summaries are patched so the stale original value is
replaced by the new one. This eliminates the 100% temporal-drift regression
observed at T=100 where the old value was frozen inside compressed cold text.
"""

name = "cascading"
Expand All @@ -53,8 +103,20 @@ def __init__(
self.cold: List[str] = []
self.turn_count = 0

# fact_key → new_value for every update seen so far
self._fact_updates: Dict[str, str] = {}

def add_message(self, role: str, content: str, turn: int) -> None:
msg = {"role": role, "content": content, "turn": turn}

parsed = _parse_update(content)
if parsed:
key_name, new_val = parsed
self._fact_updates[key_name] = new_val
# Immediately patch warm messages that are already compressible
# (they haven't hit cold yet; cold is patched in _cascade_warm)
self.cold = _patch_cold_with_update(self.cold, key_name, new_val)

self.hot.append(msg)
self.turn_count += 1

Expand All @@ -80,8 +142,14 @@ def _cascade_warm(self) -> None:
summary = _extractive_summary(overflow)
self.cold.append(summary)

# Patch all cold entries with every known fact update so no stale
# values survive compression into the cold tier.
for key_name, new_val in self._fact_updates.items():
self.cold = _patch_cold_with_update(self.cold, key_name, new_val)

if len(self.cold) > self.cold_max:
merged = self.cold[0] + " | " + self.cold[1]
# Merge oldest two; newer content first so it survives truncation
merged = self.cold[1] + " | " + self.cold[0]
self.cold = [merged[:600]] + self.cold[2:]

def get_context(self, query: str, current_turn: int) -> List[Dict]:
Expand Down Expand Up @@ -116,8 +184,9 @@ def get_context(self, query: str, current_turn: int) -> List[Dict]:
return context

def reset(self) -> None:
self.hot = []
self.warm = []
self.warm_embs = []
self.cold = []
self.turn_count = 0
self.hot = []
self.warm = []
self.warm_embs = []
self.cold = []
self.turn_count = 0
self._fact_updates = {}