Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 180 additions & 3 deletions evaluation/stats.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
"""
evaluation/stats.py — Statistical helpers for multi-seed aggregation.
evaluation/stats.py — Statistical helpers for multi-seed aggregation and
forgetting-curve analysis.

Provides mean ± std + 95% confidence intervals over multiple benchmark runs.
Provides:
- mean ± std + 95% confidence intervals over multiple benchmark runs
- fit_forgetting_curve() — fits Ebbinghaus / exponential decay to recall@T
data and returns half-life, stability, and R² goodness-of-fit

Closes #6.
"""

import math
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple


# ── Basic statistics ──────────────────────────────────────────────────────────

def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when the list is empty."""
    if not values:
        return 0.0
    return sum(values) / len(values)
Expand Down Expand Up @@ -72,3 +80,172 @@ def aggregate_checkpoint_series(
aggregate_metric([run[i] for run in series])
for i in range(n_checkpoints)
]


# ── Forgetting-curve fitting ──────────────────────────────────────────────────

def _r_squared(observed: List[float], predicted: List[float]) -> float:
    """
    Coefficient of determination R² of *predicted* against *observed*.

    Returns
        nan   — fewer than two observations
        1.0   — observations have zero variance and the residuals are zero
        -inf  — observations have zero variance but the residuals are not zero
        1 − SS_res / SS_tot in every other case (1 = perfect fit; may be
        negative when the prediction is worse than the observed mean)

    Residuals are paired with zip(), so surplus items in the longer list do
    not contribute to SS_res.
    """
    n_obs = len(observed)
    if n_obs < 2:
        return float("nan")

    # n_obs >= 2 here, so the mean needs no empty-list guard.
    centre = sum(observed) / n_obs
    total_ss = sum((obs - centre) ** 2 for obs in observed)
    residual_ss = sum((obs - fit) ** 2 for obs, fit in zip(observed, predicted))

    if total_ss != 0:
        return 1.0 - residual_ss / total_ss
    # Flat observations: the ratio is undefined, so report the two limits.
    if residual_ss == 0:
        return 1.0
    return float("-inf")


def _fit_exponential(
    turns: List[float],
    recalls: List[float],
) -> Tuple[float, float, float]:
    """
    Least-squares fit of R(t) = a · exp(−k · t), carried out as a
    straight-line regression of log(R) on t.

    Points whose recall is not strictly positive are left out of the
    regression (their logarithm is undefined), but R² is still evaluated
    against every (turn, recall) pair that was passed in.

    Returns (a, k, r_squared).
        a — fitted recall at t=0
        k — decay rate; positive when recall falls as t grows
    All three are nan when fewer than two usable points remain, or when
    n·Σt² − (Σt)² evaluates to zero (in exact arithmetic: every usable point
    sits at the same t).
    """
    undefined = (float("nan"), float("nan"), float("nan"))

    usable = [(t, r) for t, r in zip(turns, recalls) if r > 0]
    if len(usable) < 2:
        return undefined

    ts = [t for t, _ in usable]
    log_rs = [math.log(r) for _, r in usable]

    # Normal-equation sums for the line log(R) = log(a) + slope · t.
    count = len(ts)
    sum_t = sum(ts)
    sum_log = sum(log_rs)
    sum_tt = sum(t * t for t in ts)
    sum_t_log = sum(t * lr for t, lr in zip(ts, log_rs))

    spread = count * sum_tt - sum_t * sum_t
    if spread == 0:
        return undefined

    slope = (count * sum_t_log - sum_t * sum_log) / spread
    k = -slope  # the model writes the slope as −k
    a = math.exp((sum_log - slope * sum_t) / count)

    # Goodness of fit is measured in recall space, not log space.
    fitted = [a * math.exp(-k * t) for t in turns]
    return (a, k, _r_squared(recalls, fitted))


def _fit_ebbinghaus(
    turns: List[float],
    recalls: List[float],
    t_max: float,
) -> Tuple[float, float]:
    """
    Fit R(t) = exp(−t_norm / (S · sqrt(1 + t_norm))) by evaluating every S on
    a fixed grid and keeping the one with the highest R².

    t_norm = t / max(t_max, 1), so a t_max below 1 leaves the turns unscaled.
    Any t_norm ≤ 0 is predicted as full recall (1.0).

    Returns (S, r_squared).
        S — stability constant (higher = slower forgetting)
    Both are nan when fewer than two turns are given. When no grid point
    scores an R² above −inf the result is (1.0, −inf).
    """
    if len(turns) < 2:
        return (float("nan"), float("nan"))

    scale = max(t_max, 1)
    normalised = [t / scale for t in turns]

    def _curve(stability: float) -> List[float]:
        # Model prediction at every normalised turn for one candidate S.
        return [
            1.0 if tn <= 0 else math.exp(-tn / (stability * math.sqrt(1.0 + tn)))
            for tn in normalised
        ]

    # Single-resolution grid: S = i · 0.1 for i = 1 … 200 (about 0.1 to 20.0).
    winner, winner_r2 = 1.0, float("-inf")
    for step in range(1, 201):
        candidate = step * 0.1
        score = _r_squared(recalls, _curve(candidate))
        # Skip undefined scores; ties keep the earlier (smaller) S.
        if math.isnan(score) or score <= winner_r2:
            continue
        winner, winner_r2 = candidate, score

    return (winner, winner_r2)


def _ebbinghaus_half_life(stability: float, scale: float) -> float:
    """Turns at which the Ebbinghaus curve first reaches recall ≤ 0.5, or inf if the scan never gets there."""
    # Scan t_norm = 0.01, 0.02, …, 100.00 and stop at the first crossing.
    for step in range(1, 10001):
        tn = step / 100.0
        if math.exp(-tn / (stability * math.sqrt(1.0 + tn))) <= 0.5:
            return round(tn * scale, 2)
    return float("inf")


def fit_forgetting_curve(
    checkpoints: List[int],
    recalls: List[float],
) -> Dict:
    """
    Fit forgetting-curve models to a backend's recall@T time-series and return
    interpretable memory-stability statistics.

    Models fitted
    -------------
    exponential : R(t) = a · exp(−k · t)
        Exponential decay with a free intercept a and rate k.
    ebbinghaus  : R(t) = exp(−t_norm / (S · √(1 + t_norm)))
        Ebbinghaus-style curve with a single stability parameter S, where
        t_norm = t / max(largest checkpoint, 1).

    Parameters
    ----------
    checkpoints : list of turn numbers at which recall was measured
    recalls     : list of recall values ∈ [0, 1] corresponding to each checkpoint

    Returns
    -------
    On invalid input, a dict with the single key "error" holding a message:
    one message for lists of different length, another for fewer than two
    pairs.

    Otherwise a dict with keys:
        exponential:
            a         — initial recall estimate at t=0
            k         — decay rate (nats per turn)
            half_life — turns until recall halves (ln(2)/k); None unless k > 0
            r2        — R² goodness-of-fit
        ebbinghaus:
            stability — S parameter (higher = more stable memory)
            half_life — turns until the fitted curve drops to 0.5; None when
                        that does not happen within 100 × the time scale
            r2        — R² goodness-of-fit
        checkpoints : input turns (echoed for convenience)
        recalls     : input recalls rounded to 4 decimals
    Any statistic that came out as nan is reported as None.
    """
    # Report a length mismatch as such instead of as "too few pairs".
    if len(checkpoints) != len(recalls):
        return {
            "error": (
                "checkpoints and recalls must have the same length "
                f"(got {len(checkpoints)} and {len(recalls)})."
            )
        }
    if len(checkpoints) < 2:
        return {"error": "Need at least 2 (checkpoint, recall) pairs."}

    turns = [float(t) for t in checkpoints]
    t_max = max(turns)

    # ── Exponential fit ──────────────────────────────────────────────────────
    a, k, r2_exp = _fit_exponential(turns, recalls)
    # A non-positive or undefined rate means recall never halves.
    half_life_exp = math.log(2) / k if (not math.isnan(k) and k > 0) else float("inf")

    # ── Ebbinghaus fit ───────────────────────────────────────────────────────
    S, r2_ebb = _fit_ebbinghaus(turns, recalls, t_max)
    # De-normalise with the same divisor the fit used (max(t_max, 1)) so the
    # half-life is expressed on the scale the curve was actually fitted on.
    half_life_ebb = (
        float("inf") if math.isnan(S) else _ebbinghaus_half_life(S, max(t_max, 1))
    )

    return {
        "exponential": {
            "a": round(a, 4) if not math.isnan(a) else None,
            "k": round(k, 6) if not math.isnan(k) else None,
            "half_life": round(half_life_exp, 2) if not math.isinf(half_life_exp) else None,
            "r2": round(r2_exp, 4) if not math.isnan(r2_exp) else None,
        },
        "ebbinghaus": {
            "stability": round(S, 4) if not math.isnan(S) else None,
            "half_life": half_life_ebb if not math.isinf(half_life_ebb) else None,
            "r2": round(r2_ebb, 4) if not math.isnan(r2_ebb) else None,
        },
        "checkpoints": checkpoints,
        "recalls": [round(r, 4) for r in recalls],
    }
47 changes: 47 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
Realistic chunked RAG backend:
python main.py --backends naive rag_chunked cascading

Forgetting-curve analysis (fit Ebbinghaus + exponential to recall@T data):
python main.py --fit-curves
python main.py --seeds 5 --fit-curves

Other options:
python main.py --turns 50 --backends naive rag --log
python main.py --list-providers
Expand Down Expand Up @@ -66,6 +70,9 @@ def main() -> None:
parser.add_argument("--decay", type=str, default="ebbinghaus",
choices=["ebbinghaus", "exponential", "linear", "default"],
help="Temporal decay function for CascadingMemory warm tier")
parser.add_argument("--fit-curves", action="store_true",
help="After benchmarking, fit Ebbinghaus + exponential decay curves "
"to recall@T data and report half-life / stability / R²")
args = parser.parse_args()

# ── List providers ────────────────────────────────────────────────────────
Expand Down Expand Up @@ -161,11 +168,51 @@ def main() -> None:
})
print(f"Experiment logged -> {path}")

# ── Forgetting-curve analysis ─────────────────────────────────────────────
if args.fit_curves:
from evaluation.stats import fit_forgetting_curve
checkpoints = sorted(args.checkpoints)
print("\nFORGETTING CURVE FIT (Ebbinghaus + Exponential)")
print("-" * 65)
if multi_seed:
for name in args.backends:
if name not in aggregated:
continue
mean_recalls = [stat["mean"] for stat in aggregated[name]["recall"]]
fit = fit_forgetting_curve(checkpoints, mean_recalls)
_print_curve_fit(name, fit)
else:
for name in args.backends:
if name not in display:
continue
fit = fit_forgetting_curve(checkpoints, display[name]["recall"])
_print_curve_fit(name, fit)

print("Visualise: streamlit run dashboard.py")


# ── Output helpers ────────────────────────────────────────────────────────────


def _print_curve_fit(backend: str, fit: dict) -> None:
    """
    Print the forgetting-curve fit for one backend.

    A *fit* carrying an "error" key is reported as a single line holding the
    backend name and the message. Otherwise three lines are printed: the
    backend name, the exponential fit (k, half-life, R²) and the Ebbinghaus
    fit (S, half-life, R²). Any statistic stored as None is shown as "N/A".
    """
    if "error" in fit:
        print(f" {backend:<14} {fit['error']}")
        return

    def _shown(value, spec: str, unit: str = "") -> str:
        # None stands for a statistic that is not available.
        return "N/A" if value is None else format(value, spec) + unit

    exponential = fit["exponential"]
    ebbinghaus = fit["ebbinghaus"]

    # Build both detail lines before printing so a malformed fit dict fails
    # before anything is written.
    exp_line = (
        f" Exponential k={_shown(exponential['k'], '.6f')}"
        f" half-life={_shown(exponential['half_life'], '.1f', ' turns')}"
        f" R²={_shown(exponential['r2'], '.3f')}"
    )
    ebb_line = (
        f" Ebbinghaus S={_shown(ebbinghaus['stability'], '.4f')}"
        f" half-life={_shown(ebbinghaus['half_life'], '.1f', ' turns')}"
        f" R²={_shown(ebbinghaus['r2'], '.3f')}"
    )

    print(f" {backend}")
    print(exp_line)
    print(ebb_line)




def _print_single_seed_results(display: dict, backends: list) -> None:
checkpoints = display["checkpoints"]
col = " ".join(f"T={c:3d}" for c in checkpoints)
Expand Down