Skip to content

Commit f9178da

Browse files
authored
Merge pull request #66 from igerber/claude/linear-regression-helper-tuifh
2 parents 01d2306 + d283695 commit f9178da

File tree

9 files changed

+1240
-91
lines changed

9 files changed

+1240
-91
lines changed

CLAUDE.md

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,10 +96,12 @@ pytest tests/test_rust_backend.py -v
9696
- `bacon_decompose()` - Convenience function for quick decomposition
9797
- Integrated with `TwoWayFixedEffects.decompose()` method
9898

99-
- **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0):
99+
- **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0+):
100100
- `solve_ols()` - OLS solver using scipy's gelsy LAPACK driver (QR-based, faster than SVD)
101101
- `compute_robust_vcov()` - Vectorized HC1 and cluster-robust variance-covariance estimation
102102
- `compute_r_squared()` - R-squared and adjusted R-squared computation
103+
- `LinearRegression` - High-level OLS helper class with unified coefficient extraction and inference
104+
- `InferenceResult` - Dataclass container for coefficient-level inference (SE, t-stat, p-value, CI)
103105
- Single optimization point for all estimators (reduces code duplication)
104106
- Cluster-robust SEs use pandas groupby instead of O(n × clusters) loop
105107

@@ -270,7 +272,7 @@ Tests mirror the source modules:
270272
- `tests/test_sun_abraham.py` - Tests for SunAbraham interaction-weighted estimator
271273
- `tests/test_triple_diff.py` - Tests for Triple Difference (DDD) estimator
272274
- `tests/test_bacon.py` - Tests for Goodman-Bacon decomposition
273-
- `tests/test_linalg.py` - Tests for unified OLS backend and robust variance estimation
275+
- `tests/test_linalg.py` - Tests for unified OLS backend, robust variance estimation, LinearRegression helper, and InferenceResult
274276
- `tests/test_utils.py` - Tests for parallel trends, robust SE, synthetic weights
275277
- `tests/test_diagnostics.py` - Tests for placebo tests
276278
- `tests/test_wild_bootstrap.py` - Tests for wild cluster bootstrap

ROADMAP.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ Ongoing maintenance and developer experience.
224224
### Code Quality
225225

226226
- Extract shared within-transformation logic to utils
227-
- Consolidate linear regression helpers
227+
- ~~Consolidate linear regression helpers~~ ✓ Done (v2.1): Added `LinearRegression` helper class and `InferenceResult` dataclass in `linalg.py`. All major estimators (DifferenceInDifferences, TwoWayFixedEffects, SunAbraham, TripleDifference) now use the unified helper for coefficient extraction and inference.
228228
- Consider splitting `staggered.py` (1800+ lines)
229229

230230
### Documentation

diff_diff/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,10 @@
3030
run_all_placebo_tests,
3131
run_placebo_test,
3232
)
33+
from diff_diff.linalg import (
34+
InferenceResult,
35+
LinearRegression,
36+
)
3337
from diff_diff.estimators import (
3438
DifferenceInDifferences,
3539
MultiPeriodDiD,
@@ -199,4 +203,7 @@
199203
"plot_pretrends_power",
200204
# Rust backend
201205
"HAS_RUST_BACKEND",
206+
# Linear algebra helpers
207+
"LinearRegression",
208+
"InferenceResult",
202209
]

diff_diff/estimators.py

Lines changed: 32 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,12 @@
1717
import numpy as np
1818
import pandas as pd
1919

20-
from diff_diff.linalg import compute_r_squared, compute_robust_vcov, solve_ols
20+
from diff_diff.linalg import (
21+
LinearRegression,
22+
compute_r_squared,
23+
compute_robust_vcov,
24+
solve_ols,
25+
)
2126
from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect
2227
from diff_diff.utils import (
2328
WildBootstrapResults,
@@ -262,56 +267,45 @@ def fit(
262267
X = np.column_stack([X, dummies[col].values.astype(float)])
263268
var_names.append(col)
264269

265-
# Fit OLS using unified backend
266-
coefficients, residuals, fitted, vcov = solve_ols(
267-
X, y, return_fitted=True, return_vcov=False
268-
)
269-
r_squared = compute_r_squared(y, residuals)
270-
271-
# Extract ATT (coefficient on interaction term)
270+
# Extract ATT index (coefficient on interaction term)
272271
att_idx = 3 # Index of interaction term
273272
att_var_name = f"{treatment}:{time}"
274273
assert var_names[att_idx] == att_var_name, (
275274
f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, "
276275
f"but found '{var_names[att_idx]}'"
277276
)
278-
att = coefficients[att_idx]
279277

280-
# Compute degrees of freedom (used for analytical inference)
281-
df = len(y) - X.shape[1] - n_absorbed_effects
278+
# Always use LinearRegression for initial fit (unified code path)
279+
# For wild bootstrap, we don't need cluster SEs from the initial fit
280+
cluster_ids = data[self.cluster].values if self.cluster is not None else None
281+
reg = LinearRegression(
282+
include_intercept=False, # Intercept already in X
283+
robust=self.robust,
284+
cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
285+
alpha=self.alpha,
286+
).fit(X, y, df_adjustment=n_absorbed_effects)
287+
288+
coefficients = reg.coefficients_
289+
residuals = reg.residuals_
290+
fitted = reg.fitted_values_
291+
att = coefficients[att_idx]
282292

283-
# Compute standard errors and inference
293+
# Get inference - either from bootstrap or analytical
284294
if self.inference == "wild_bootstrap" and self.cluster is not None:
285-
# Wild cluster bootstrap for few-cluster inference
286-
cluster_ids = data[self.cluster].values
295+
# Override with wild cluster bootstrap inference
287296
se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference(
288297
X, y, residuals, cluster_ids, att_idx
289298
)
290-
elif self.cluster is not None:
291-
cluster_ids = data[self.cluster].values
292-
vcov = compute_robust_vcov(X, residuals, cluster_ids)
293-
se = np.sqrt(vcov[att_idx, att_idx])
294-
t_stat = att / se
295-
p_value = compute_p_value(t_stat, df=df)
296-
conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
297-
elif self.robust:
298-
vcov = compute_robust_vcov(X, residuals)
299-
se = np.sqrt(vcov[att_idx, att_idx])
300-
t_stat = att / se
301-
p_value = compute_p_value(t_stat, df=df)
302-
conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
303299
else:
304-
# Classical OLS standard errors
305-
n = len(y)
306-
k = X.shape[1]
307-
mse = np.sum(residuals**2) / (n - k)
308-
# Use solve() instead of inv() for numerical stability
309-
# solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
310-
vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
311-
se = np.sqrt(vcov[att_idx, att_idx])
312-
t_stat = att / se
313-
p_value = compute_p_value(t_stat, df=df)
314-
conf_int = compute_confidence_interval(att, se, self.alpha, df=df)
300+
# Use analytical inference from LinearRegression
301+
vcov = reg.vcov_
302+
inference = reg.get_inference(att_idx)
303+
se = inference.se
304+
t_stat = inference.t_stat
305+
p_value = inference.p_value
306+
conf_int = inference.conf_int
307+
308+
r_squared = compute_r_squared(y, residuals)
315309

316310
# Count observations
317311
n_treated = int(np.sum(d))

0 commit comments

Comments
 (0)