Skip to content

Commit e1936d9

Browse files
authored
FIX: In the regression setting, cv=LeaveOneGroupOut() and cv=LeavePGroupsOut() are not working (#696)
1 parent b48b6b2 commit e1936d9

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

mapie/estimator/regressor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,12 @@ def fit_single_estimator(
528528
**fit_params
529529
) -> EnsembleRegressor:
530530

531-
self.use_split_method_ = _check_no_agg_cv(X, self.cv, self.no_agg_cv_)
531+
self.use_split_method_ = _check_no_agg_cv(
532+
X,
533+
self.cv,
534+
self.no_agg_cv_,
535+
groups=groups
536+
)
532537
single_estimator_: RegressorMixin
533538

534539
if self.cv == "prefit":

mapie/tests/test_regression.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.linear_model import LinearRegression
1717
from sklearn.model_selection import (
1818
GroupKFold, KFold, LeaveOneOut, PredefinedSplit, ShuffleSplit,
19-
train_test_split
19+
train_test_split, LeaveOneGroupOut, LeavePGroupsOut
2020
)
2121
from sklearn.pipeline import Pipeline, make_pipeline
2222
from sklearn.preprocessing import OneHotEncoder
@@ -290,6 +290,30 @@ def test_predict_output_shape(
290290
assert y_pis.shape == (X.shape[0], 2, n_alpha)
291291

292292

293+
@pytest.mark.parametrize(
294+
"cv, n_groups",
295+
[
296+
(LeaveOneGroupOut(), 5),
297+
(LeavePGroupsOut(2), 10),
298+
],
299+
)
300+
def test_group_cv_fit_runs_regressor(cv, n_groups):
301+
"""
302+
`_MapieRegressor` should accept group‑based CV splitters
303+
(LeaveOneGroupOut, LeavePGroupsOut) without raising.
304+
"""
305+
X, y = make_regression(
306+
n_samples=n_groups * 30,
307+
n_features=5,
308+
noise=0.1,
309+
random_state=42,
310+
)
311+
groups = np.repeat(np.arange(n_groups), 30)
312+
313+
# Ensuring `.fit` does not raise
314+
_MapieRegressor(cv=cv).fit(X, y, groups=groups)
315+
316+
293317
@pytest.mark.parametrize("delta", [0.6, 0.8])
294318
@pytest.mark.parametrize("n_calib", [10 + i for i in range(13)] + [50, 100])
295319
def test_coverage_validity(delta: float, n_calib: int) -> None:

0 commit comments

Comments
 (0)