|
17 | 17 | import numpy as np |
18 | 18 | import pandas as pd |
19 | 19 |
|
20 | | -from diff_diff.linalg import compute_r_squared, compute_robust_vcov, solve_ols |
| 20 | +from diff_diff.linalg import ( |
| 21 | + LinearRegression, |
| 22 | + compute_r_squared, |
| 23 | + compute_robust_vcov, |
| 24 | + solve_ols, |
| 25 | +) |
21 | 26 | from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect |
22 | 27 | from diff_diff.utils import ( |
23 | 28 | WildBootstrapResults, |
@@ -262,56 +267,45 @@ def fit( |
262 | 267 | X = np.column_stack([X, dummies[col].values.astype(float)]) |
263 | 268 | var_names.append(col) |
264 | 269 |
|
265 | | - # Fit OLS using unified backend |
266 | | - coefficients, residuals, fitted, vcov = solve_ols( |
267 | | - X, y, return_fitted=True, return_vcov=False |
268 | | - ) |
269 | | - r_squared = compute_r_squared(y, residuals) |
270 | | - |
271 | | - # Extract ATT (coefficient on interaction term) |
| 270 | + # Extract ATT index (coefficient on interaction term) |
272 | 271 | att_idx = 3 # Index of interaction term |
273 | 272 | att_var_name = f"{treatment}:{time}" |
274 | 273 | assert var_names[att_idx] == att_var_name, ( |
275 | 274 | f"ATT index mismatch: expected '{att_var_name}' at index {att_idx}, " |
276 | 275 | f"but found '{var_names[att_idx]}'" |
277 | 276 | ) |
278 | | - att = coefficients[att_idx] |
279 | 277 |
|
280 | | - # Compute degrees of freedom (used for analytical inference) |
281 | | - df = len(y) - X.shape[1] - n_absorbed_effects |
| 278 | + # Always use LinearRegression for initial fit (unified code path) |
| 279 | + # For wild bootstrap, we don't need cluster SEs from the initial fit |
| 280 | + cluster_ids = data[self.cluster].values if self.cluster is not None else None |
| 281 | + reg = LinearRegression( |
| 282 | + include_intercept=False, # Intercept already in X |
| 283 | + robust=self.robust, |
| 284 | + cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None, |
| 285 | + alpha=self.alpha, |
| 286 | + ).fit(X, y, df_adjustment=n_absorbed_effects) |
| 287 | + |
| 288 | + coefficients = reg.coefficients_ |
| 289 | + residuals = reg.residuals_ |
| 290 | + fitted = reg.fitted_values_ |
| 291 | + att = coefficients[att_idx] |
282 | 292 |
|
283 | | - # Compute standard errors and inference |
| 293 | + # Get inference - either from bootstrap or analytical |
284 | 294 | if self.inference == "wild_bootstrap" and self.cluster is not None: |
285 | | - # Wild cluster bootstrap for few-cluster inference |
286 | | - cluster_ids = data[self.cluster].values |
| 295 | + # Override with wild cluster bootstrap inference |
287 | 296 | se, p_value, conf_int, t_stat, vcov, _ = self._run_wild_bootstrap_inference( |
288 | 297 | X, y, residuals, cluster_ids, att_idx |
289 | 298 | ) |
290 | | - elif self.cluster is not None: |
291 | | - cluster_ids = data[self.cluster].values |
292 | | - vcov = compute_robust_vcov(X, residuals, cluster_ids) |
293 | | - se = np.sqrt(vcov[att_idx, att_idx]) |
294 | | - t_stat = att / se |
295 | | - p_value = compute_p_value(t_stat, df=df) |
296 | | - conf_int = compute_confidence_interval(att, se, self.alpha, df=df) |
297 | | - elif self.robust: |
298 | | - vcov = compute_robust_vcov(X, residuals) |
299 | | - se = np.sqrt(vcov[att_idx, att_idx]) |
300 | | - t_stat = att / se |
301 | | - p_value = compute_p_value(t_stat, df=df) |
302 | | - conf_int = compute_confidence_interval(att, se, self.alpha, df=df) |
303 | 299 | else: |
304 | | - # Classical OLS standard errors |
305 | | - n = len(y) |
306 | | - k = X.shape[1] |
307 | | - mse = np.sum(residuals**2) / (n - k) |
308 | | - # Use solve() instead of inv() for numerical stability |
309 | | - # solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse |
310 | | - vcov = np.linalg.solve(X.T @ X, mse * np.eye(k)) |
311 | | - se = np.sqrt(vcov[att_idx, att_idx]) |
312 | | - t_stat = att / se |
313 | | - p_value = compute_p_value(t_stat, df=df) |
314 | | - conf_int = compute_confidence_interval(att, se, self.alpha, df=df) |
| 300 | + # Use analytical inference from LinearRegression |
| 301 | + vcov = reg.vcov_ |
| 302 | + inference = reg.get_inference(att_idx) |
| 303 | + se = inference.se |
| 304 | + t_stat = inference.t_stat |
| 305 | + p_value = inference.p_value |
| 306 | + conf_int = inference.conf_int |
| 307 | + |
| 308 | + r_squared = compute_r_squared(y, residuals) |
315 | 309 |
|
316 | 310 | # Count observations |
317 | 311 | n_treated = int(np.sum(d)) |
|
0 commit comments