1
1
from collections import Counter
2
2
from logging import getLogger
3
+ from typing import Optional , Tuple
3
4
import pandas
4
5
import numpy
5
6
from .dataframe_helpers import dataframe_shuffle
@@ -449,14 +450,15 @@ def double_merge(d):
449
450
450
451
451
452
def train_test_apart_stratify (
452
- df ,
453
+ df : pandas . DataFrame ,
453
454
group ,
454
- test_size = 0.25 ,
455
- train_size = None ,
456
- stratify = None ,
457
- force = False ,
458
- random_state = None ,
459
- ):
455
+ test_size : Optional [float ] = 0.25 ,
456
+ train_size : Optional [float ] = None ,
457
+ stratify : Optional [str ] = None ,
458
+ force : bool = False ,
459
+ random_state : Optional [int ] = None ,
460
+ sorted_indices : bool = False ,
461
+ ) -> Tuple ["StreamingDataFrame" , "StreamingDataFrame" ]: # noqa: F821
460
462
"""
461
463
This split is for a specific case where data is linked
462
464
in one way. Let's assume we have two ids as we have
@@ -474,6 +476,8 @@ def train_test_apart_stratify(
474
476
:param force: if True, tries to get at least one example on the test side
475
477
for each value of the column *stratify*
476
478
:param random_state: seed for random generators
479
+ :param sorted_indices: sort index first,
480
+ see issue `41 <https://github.com/sdpython/pandas-streaming/issues/41>`_
477
481
:return: Two :class:`StreamingDataFrame
478
482
<pandas_streaming.df.dataframe.StreamingDataFrame>`, one
479
483
for train, one for test.
@@ -540,10 +544,15 @@ def train_test_apart_stratify(
540
544
541
545
split = {}
542
546
for _ , k in sorted_hist :
543
- not_assigned = [c for c in ids [k ] if c not in split ]
547
+ indices = sorted (ids [k ]) if sorted_indices else ids [k ]
548
+ not_assigned , assigned = [], []
549
+ for c in indices :
550
+ if c in split :
551
+ assigned .append (c )
552
+ else :
553
+ not_assigned .append (c )
544
554
if len (not_assigned ) == 0 :
545
555
continue
546
- assigned = [c for c in ids [k ] if c in split ]
547
556
nb_test = sum (split [c ] for c in assigned )
548
557
expected = min (len (ids [k ]), int (test_size * len (ids [k ]) + 0.5 )) - nb_test
549
558
if force and expected == 0 and nb_test == 0 :
0 commit comments