Skip to content

Commit

Permalink
Merge pull request scikit-learn#6379 from lesteve/fix-stratified-shuf…
Browse files Browse the repository at this point in the history
…fle-split-train-test-overlap

[MRG+1] fix StratifiedShuffleSplit train and test overlap
  • Loading branch information
amueller committed Feb 29, 2016
2 parents 1049642 + 07728d9 commit 150afe6
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 6 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ Bug fixes
``transform`` or ``predict_proba`` are called on the non-fitted estimator.
by `Sebastian Raschka`_.

- Fixed bug in :class:`model_selection.StratifiedShuffleSplit`
where train and test samples could overlap in some edge cases,
see `#6121 <https://github.com/scikit-learn/scikit-learn/issues/6121>`_ for
more details. By `Loic Esteve`_.

API changes summary
-------------------

Expand Down
11 changes: 8 additions & 3 deletions sklearn/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,14 +1010,19 @@ def _iter_indices(self):
# Because of rounding issues (as n_train and n_test are not
# divisors of the number of elements per class), we may end
# up here with fewer samples in train and test than asked for.
if len(train) < self.n_train or len(test) < self.n_test:
if len(train) + len(test) < self.n_train + self.n_test:
# We complete by randomly assigning the missing indices
missing_idx = np.where(bincount(train + test,
minlength=len(self.y)) == 0,
)[0]
missing_idx = rng.permutation(missing_idx)
train.extend(missing_idx[:(self.n_train - len(train))])
test.extend(missing_idx[-(self.n_test - len(test)):])
n_missing_train = self.n_train - len(train)
n_missing_test = self.n_test - len(test)

if n_missing_train > 0:
train.extend(missing_idx[:n_missing_train])
if n_missing_test > 0:
test.extend(missing_idx[-n_missing_test:])

train = rng.permutation(train)
test = rng.permutation(test)
Expand Down
11 changes: 8 additions & 3 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,13 +1108,18 @@ def _iter_indices(self, X, y, labels=None):
# Because of rounding issues (as n_train and n_test are not
# divisors of the number of elements per class), we may end
# up here with fewer samples in train and test than asked for.
if len(train) < n_train or len(test) < n_test:
if len(train) + len(test) < n_train + n_test:
# We complete by randomly assigning the missing indices
missing_indices = np.where(bincount(train + test,
minlength=len(y)) == 0)[0]
missing_indices = rng.permutation(missing_indices)
train.extend(missing_indices[:(n_train - len(train))])
test.extend(missing_indices[-(n_test - len(test)):])
n_missing_train = n_train - len(train)
n_missing_test = n_test - len(test)

if n_missing_train > 0:
train.extend(missing_indices[:n_missing_train])
if n_missing_test > 0:
test.extend(missing_indices[-n_missing_test:])

train = rng.permutation(train)
test = rng.permutation(test)
Expand Down
14 changes: 14 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,20 @@ def assert_counts_are_ok(idx_counts, p):
assert_counts_are_ok(test_counts, ex_test_p)


def test_stratified_shuffle_split_overlap_train_test_bug():
    # Regression test: the train and test indices produced by a split
    # must be disjoint. Original bug report:
    # https://github.com/scikit-learn/scikit-learn/issues/6121

    # 22 samples: classes 0-3 appear 3 times each, classes 4-5 appear
    # 5 times each, so no class size splits evenly with test_size=0.5.
    y = [0, 1, 2, 3] * 3 + [4, 5] * 5
    X = np.ones_like(y)

    cv = StratifiedShuffleSplit(n_iter=1, test_size=0.5, random_state=0)

    # Only the first (and, with n_iter=1, only) split is inspected.
    train, test = next(iter(cv.split(X=X, y=y)))

    # An empty intersection means no index was put in both sets.
    assert_array_equal(np.intersect1d(train, test), [])


def test_predefinedsplit_with_kfold_split():
# Check that PredefinedSplit can reproduce a split generated by Kfold.
folds = -1 * np.ones(10)
Expand Down
12 changes: 12 additions & 0 deletions sklearn/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,18 @@ def assert_counts_are_ok(idx_counts, p):
assert_counts_are_ok(test_counts, ex_test_p)


def test_stratified_shuffle_split_overlap_train_test_bug():
    # Regression test: the train and test indices produced by a split
    # must be disjoint. Original bug report:
    # https://github.com/scikit-learn/scikit-learn/issues/6121

    # 22 labels: classes 0-3 appear 3 times each, classes 4-5 appear
    # 5 times each, so no class size splits evenly with test_size=0.5.
    labels = [0, 1, 2, 3] * 3 + [4, 5] * 5

    cv = cval.StratifiedShuffleSplit(
        labels, n_iter=1, test_size=0.5, random_state=0)

    # Only the first (and, with n_iter=1, only) split is inspected.
    train, test = next(iter(cv))

    # An empty intersection means no index was put in both sets.
    assert_array_equal(np.intersect1d(train, test), [])


def test_predefinedsplit_with_kfold_split():
# Check that PredefinedSplit can reproduce a split generated by Kfold.
folds = -1 * np.ones(10)
Expand Down

0 comments on commit 150afe6

Please sign in to comment.