Skip to content

Commit d0813a3

Browse files
committed
Use sorted indices
1 parent 9753f32 commit d0813a3

File tree

2 files changed

+49
-9
lines changed

2 files changed

+49
-9
lines changed

_unittests/ut_df/test_connex_split_cat.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,37 @@ def test_cat_strat(self):
3939
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
4040
)
4141

42+
def test_cat_strat_sorted(self):
43+
df = pandas.DataFrame(
44+
[
45+
dict(a=1, b="e"),
46+
dict(a=2, b="e"),
47+
dict(a=4, b="f"),
48+
dict(a=8, b="f"),
49+
dict(a=32, b="f"),
50+
dict(a=16, b="f"),
51+
]
52+
)
53+
54+
train, test = train_test_apart_stratify(
55+
df, group="a", stratify="b", test_size=0.5, sorted_indices=True
56+
)
57+
self.assertEqual(train.shape[1], test.shape[1])
58+
self.assertEqual(train.shape[0] + test.shape[0], df.shape[0])
59+
c1 = Counter(train["b"])
60+
c2 = Counter(train["b"])
61+
self.assertEqual(c1, c2)
62+
63+
self.assertRaise(
64+
lambda: train_test_apart_stratify(
65+
df, group=None, stratify="b", test_size=0.5, sorted_indices=True
66+
),
67+
ValueError,
68+
)
69+
self.assertRaise(
70+
lambda: train_test_apart_stratify(df, group="b", test_size=0.5), ValueError
71+
)
72+
4273
def test_cat_strat_multi(self):
4374
df = pandas.DataFrame(
4475
[

pandas_streaming/df/connex_split.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import Counter
22
from logging import getLogger
3+
from typing import Optional, Tuple
34
import pandas
45
import numpy
56
from .dataframe_helpers import dataframe_shuffle
@@ -449,14 +450,15 @@ def double_merge(d):
449450

450451

451452
def train_test_apart_stratify(
452-
df,
453+
df: pandas.DataFrame,
453454
group,
454-
test_size=0.25,
455-
train_size=None,
456-
stratify=None,
457-
force=False,
458-
random_state=None,
459-
):
455+
test_size: Optional[float] = 0.25,
456+
train_size: Optional[float] = None,
457+
stratify: Optional[str] = None,
458+
force: bool = False,
459+
random_state: Optional[int] = None,
460+
sorted_indices: bool = False,
461+
) -> Tuple["StreamingDataFrame", "StreamingDataFrame"]: # noqa: F821
460462
"""
461463
This split is for a specific case where data is linked
462464
in one way. Let's assume we have two ids as we have
@@ -474,6 +476,8 @@ def train_test_apart_stratify(
474476
:param force: if True, tries to get at least one example on the test side
475477
for each value of the column *stratify*
476478
:param random_state: seed for random generators
479+
:param sorted_indices: sort index first,
480+
see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`
477481
:return: Two see :class:`StreamingDataFrame
478482
<pandas_streaming.df.dataframe.StreamingDataFrame>`, one
479483
for train, one for test.
@@ -540,10 +544,15 @@ def train_test_apart_stratify(
540544

541545
split = {}
542546
for _, k in sorted_hist:
543-
not_assigned = [c for c in ids[k] if c not in split]
547+
indices = sorted(ids[k]) if sorted_indices else ids[k]
548+
not_assigned, assigned = [], []
549+
for c in indices:
550+
if c in split:
551+
assigned.append(c)
552+
else:
553+
not_assigned.append(c)
544554
if len(not_assigned) == 0:
545555
continue
546-
assigned = [c for c in ids[k] if c in split]
547556
nb_test = sum(split[c] for c in assigned)
548557
expected = min(len(ids[k]), int(test_size * len(ids[k]) + 0.5)) - nb_test
549558
if force and expected == 0 and nb_test == 0:

0 commit comments

Comments
 (0)