Skip to content

Commit 6a5bd08

Browse files
committed
implemented CVSplitterImpl
1 parent cf74396 commit 6a5bd08

File tree

1 file changed

+43
-4
lines changed

1 file changed

+43
-4
lines changed

src/data_stack/dataset/splitter.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def get_random_splitter(ratios: List[float], seed: int):
1717
def get_stratified_splitter(ratios: List[float], seed: int):
    """Return a Splitter that delegates to a StratifiedSplitterImpl built from ``ratios`` and ``seed``."""
    stratified_impl = StratifiedSplitterImpl(ratios, seed=seed)
    return Splitter(splitter_impl=stratified_impl)
1919

20+
@staticmethod
2021
def get_nested_cv_splitter(num_outer_loop_folds: int = 5, num_inner_loop_folds: int = 2,
2122
inner_stratification: bool = True, outer_stratification: bool = True,
2223
target_pos: int = 1, shuffle: bool = True, seed: int = 1):
@@ -29,6 +30,16 @@ def get_nested_cv_splitter(num_outer_loop_folds: int = 5, num_inner_loop_folds:
2930
seed=seed)
3031
return Splitter(splitter_impl=splitter_impl)
3132

33+
@staticmethod
def get_cv_splitter(num_folds: int = 5, stratification: bool = True, target_pos: int = 1, shuffle: bool = True, seed: int = 1):
    """Return a Splitter that delegates to a CVSplitterImpl.

    All arguments are forwarded unchanged, by keyword, to the
    CVSplitterImpl constructor.
    """
    cv_arguments = {
        "num_folds": num_folds,
        "stratification": stratification,
        "target_pos": target_pos,
        "shuffle": shuffle,
        "seed": seed,
    }
    return Splitter(splitter_impl=CVSplitterImpl(**cv_arguments))
41+
42+
3243
class SplitterIF(ABC):
3344

3445
@abstractmethod
@@ -91,7 +102,7 @@ class StratifiedSplitterImpl(SplitterIF):
91102

92103
def __init__(self, ratios: List[float], seed: Optional[int] = None):
93104
self.ratios = ratios
94-
self.seed = seed
105+
self.seed = seed
95106

96107
def split(self, dataset_iterator: DatasetIteratorIF) -> List[DatasetIteratorIF]:
97108
dataset_length = len(dataset_iterator)
@@ -110,9 +121,9 @@ def _determine_split_indices(self, dataset_length: int, ratios: List[float], dat
110121
# split the data set until the desired number of splits is reached
111122
for split_ratio in ratios[:-1]:
112123
indices_split, indices_remaining, _, targets_remaining = train_test_split(indices_remaining,
113-
targets_remaining,
114-
train_size = int(initial_length*split_ratio),
115-
stratify=targets_remaining, random_state=self.seed, shuffle=True)
124+
targets_remaining,
125+
train_size=int(initial_length*split_ratio),
126+
stratify=targets_remaining, random_state=self.seed, shuffle=True)
116127
split_indices.append(indices_split)
117128
# any remaining indices are added to the last split
118129
split_indices.append(indices_remaining)
@@ -173,3 +184,31 @@ def get_indices(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[List[in
173184

174185
return outer_folds_indices, inner_fold_indices
175186

187+
188+
class CVSplitterImpl(SplitterIF):
    """K-fold cross-validation splitter.

    Partitions a dataset iterator into ``num_folds`` folds using
    scikit-learn's ``StratifiedKFold`` (when ``stratification`` is true) or
    ``KFold`` (otherwise). Each fold holds the *test* indices of one k-fold
    split.
    """

    def __init__(self,
                 num_folds: int = 5,
                 stratification: bool = True,
                 target_pos: int = 1,
                 shuffle: bool = True,
                 seed: int = 1):
        """Configure the underlying scikit-learn fold generator.

        Args:
            num_folds: Number of folds (``n_splits`` of the sklearn splitter).
            stratification: Use ``StratifiedKFold`` if true, else ``KFold``.
            target_pos: Position of the target inside each sample; only read
                when stratifying.
            shuffle: Whether sklearn shuffles the samples before splitting.
            seed: Seed used for shuffling; ignored when ``shuffle`` is false.
        """
        self.num_folds = num_folds
        self.stratification = stratification
        self.target_pos = target_pos
        # Hand sklearn the integer seed, not a RandomState instance. An
        # instance is advanced by every split() call, so repeated calls (and
        # get_indices(), which calls split() itself) would produce different
        # folds. With an int, sklearn re-seeds on every call, so the folds are
        # reproducible. sklearn requires random_state=None when shuffle=False.
        self.random_state = seed if shuffle else None

        splitter_cls = StratifiedKFold if stratification else KFold
        self.splitter = splitter_cls(n_splits=num_folds, shuffle=shuffle, random_state=self.random_state)

    def split(self, dataset_iterator: DatasetIteratorIF) -> List[DatasetIteratorView]:
        """Split ``dataset_iterator`` into one ``DatasetIteratorView`` per fold.

        Args:
            dataset_iterator: Dataset to partition. When stratifying, every
                sample must expose its target at ``self.target_pos``.

        Returns:
            A list with one view per fold, each wrapping that fold's test
            indices.
        """
        if self.stratification:
            # Stratification needs the class label of every sample.
            targets = np.array([sample[self.target_pos] for sample in dataset_iterator])
            num_samples = len(targets)
        else:
            # KFold only needs the sample count, so avoid iterating the whole
            # dataset (and do not require a target to exist).
            targets = None
            num_samples = len(dataset_iterator)

        # sklearn only uses X for its length here; the values are irrelevant.
        placeholder = np.zeros(num_samples)
        folds_indices = [test_indices.tolist()
                         for _, test_indices in self.splitter.split(X=placeholder, y=targets)]
        return [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in folds_indices]

    def get_indices(self, dataset_iterator: DatasetIteratorIF) -> List[List[int]]:
        """Return the ``indices`` attribute of every fold view produced by ``split``."""
        return [fold.indices for fold in self.split(dataset_iterator)]

0 commit comments

Comments
 (0)