@@ -17,6 +17,7 @@ def get_random_splitter(ratios: List[float], seed: int):
17
17
def get_stratified_splitter(ratios: List[float], seed: int):
    """Create a ``Splitter`` that delegates to a ``StratifiedSplitterImpl``.

    Args:
        ratios: Split ratios, passed through positionally to
            ``StratifiedSplitterImpl``.
        seed: Seed passed to ``StratifiedSplitterImpl`` as ``seed``.

    Returns:
        A ``Splitter`` wrapping the freshly built implementation.
    """
    stratified_impl = StratifiedSplitterImpl(ratios, seed=seed)
    return Splitter(splitter_impl=stratified_impl)
19
19
20
+ @staticmethod
20
21
def get_nested_cv_splitter (num_outer_loop_folds : int = 5 , num_inner_loop_folds : int = 2 ,
21
22
inner_stratification : bool = True , outer_stratification : bool = True ,
22
23
target_pos : int = 1 , shuffle : bool = True , seed : int = 1 ):
@@ -29,6 +30,16 @@ def get_nested_cv_splitter(num_outer_loop_folds: int = 5, num_inner_loop_folds:
29
30
seed = seed )
30
31
return Splitter (splitter_impl = splitter_impl )
31
32
33
@staticmethod
def get_cv_splitter(num_folds: int = 5,
                    stratification: bool = True,
                    target_pos: int = 1,
                    shuffle: bool = True,
                    seed: int = 1):
    """Create a ``Splitter`` that delegates to a ``CVSplitterImpl``.

    All arguments are forwarded unchanged, by keyword, to ``CVSplitterImpl``.

    Args:
        num_folds: Number of folds.
        stratification: Forwarded as ``stratification``.
        target_pos: Forwarded as ``target_pos``.
        shuffle: Forwarded as ``shuffle``.
        seed: Forwarded as ``seed``.

    Returns:
        A ``Splitter`` wrapping the freshly built implementation.
    """
    return Splitter(
        splitter_impl=CVSplitterImpl(
            num_folds=num_folds,
            stratification=stratification,
            target_pos=target_pos,
            shuffle=shuffle,
            seed=seed,
        )
    )
41
+
42
+
32
43
class SplitterIF (ABC ):
33
44
34
45
@abstractmethod
@@ -91,7 +102,7 @@ class StratifiedSplitterImpl(SplitterIF):
91
102
92
103
def __init__(self, ratios: List[float], seed: Optional[int] = None):
    """Remember the split configuration on the instance.

    Args:
        ratios: Split ratios, stored as ``self.ratios``.
        seed: Optional seed, stored as ``self.seed`` (``None`` by default).
    """
    self.ratios, self.seed = ratios, seed
95
106
96
107
def split (self , dataset_iterator : DatasetIteratorIF ) -> List [DatasetIteratorIF ]:
97
108
dataset_length = len (dataset_iterator )
@@ -110,9 +121,9 @@ def _determine_split_indices(self, dataset_length: int, ratios: List[float], dat
110
121
# split the data set until the desired number of splits is reached
111
122
for split_ratio in ratios [:- 1 ]:
112
123
indices_split , indices_remaining , _ , targets_remaining = train_test_split (indices_remaining ,
113
- targets_remaining ,
114
- train_size = int (initial_length * split_ratio ),
115
- stratify = targets_remaining , random_state = self .seed , shuffle = True )
124
+ targets_remaining ,
125
+ train_size = int (initial_length * split_ratio ),
126
+ stratify = targets_remaining , random_state = self .seed , shuffle = True )
116
127
split_indices .append (indices_split )
117
128
# any remaining indices are added to the last split
118
129
split_indices .append (indices_remaining )
@@ -173,3 +184,31 @@ def get_indices(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[List[in
173
184
174
185
return outer_folds_indices , inner_fold_indices
175
186
187
+
188
class CVSplitterImpl(SplitterIF):
    """K-fold cross-validation splitter built on scikit-learn's ``KFold`` / ``StratifiedKFold``.

    The data set is partitioned into ``num_folds`` folds; each fold is exposed as a
    ``DatasetIteratorView`` over the original iterator.
    """

    def __init__(self,
                 num_folds: int = 5,
                 stratification: bool = True,
                 target_pos: int = 1,
                 shuffle: bool = True,
                 seed: int = 1):
        """Configure the underlying scikit-learn splitter.

        Args:
            num_folds: Number of folds (``n_splits`` of the scikit-learn splitter).
            stratification: Use ``StratifiedKFold`` if True, plain ``KFold`` otherwise.
            target_pos: Index into each sample at which the target is read.
            shuffle: Whether samples are shuffled before being assigned to folds.
            seed: Seed used for shuffling; ignored when ``shuffle`` is False.
        """
        self.num_folds = num_folds
        # Pass the integer seed rather than a shared RandomState instance: a RandomState
        # advances on every split() call, so repeated calls (and get_indices(), which
        # calls split() again) would yield different folds. With an int, every call
        # reproduces the same folds. The value must stay None when shuffle is False,
        # because scikit-learn does not accept a seed without shuffling.
        self.random_state = seed if shuffle else None
        self.target_pos = target_pos

        splitter_cls = StratifiedKFold if stratification else KFold
        self.splitter = splitter_cls(n_splits=num_folds, shuffle=shuffle, random_state=self.random_state)

    def split(self, dataset_iterator: DatasetIteratorIF) -> List[DatasetIteratorView]:
        """Split the data set into one ``DatasetIteratorView`` per fold.

        Args:
            dataset_iterator: Iterable data set; each sample must be indexable at ``target_pos``.

        Returns:
            A list with one view per fold, in the order produced by the scikit-learn splitter.
        """
        targets = np.array([sample[self.target_pos] for sample in dataset_iterator])
        # The splitter only needs the number of samples from X, hence the zero placeholder.
        placeholder = np.zeros(len(targets))
        # Each item is a (train, test) index pair; only the test part is kept as the fold.
        folds_indices = [test_idx.tolist() for _, test_idx in self.splitter.split(X=placeholder, y=targets)]
        return [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in folds_indices]

    def get_indices(self, dataset_iterator: DatasetIteratorIF) -> List[List[int]]:
        """Return the sample indices of every fold.

        Args:
            dataset_iterator: Data set to split; see ``split``.

        Returns:
            One list of indices per fold, taken from the ``indices`` attribute of the
            views returned by ``split``.
        """
        return [fold.indices for fold in self.split(dataset_iterator)]
0 commit comments