Skip to content

Commit f38464a

Browse files
Maren Pielka authored and committed
fixed test and stratified splitter function (finally)
1 parent f3e279c commit f38464a

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/data_stack/dataset/splitter.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,19 @@ def _determine_split_indices(self, dataset_length: int, split_config: Dict, data
110110
# split the training set until the desired number of splits is reached
111111
if "train" in split_config.keys():
112112
train_indices, test_indices, train_targets, test_targets = train_test_split(indices, targets,
113-
train_size = int(len(indices)*split_config["train"]))
113+
train_size = int(len(indices)*split_config["train"]),
114+
stratify=targets)
114115
split_indices.append(train_indices)
115116
else:
116117
logging.error("Training split was not provided")
117118
sys.exit(1)
118-
for split in split_config.keys():
119+
for split in list(split_config.keys())[:-1]:
119120
if split != "train":
120121
train_indices, test_indices, train_targets, test_targets = train_test_split(test_indices, test_targets,
121-
train_size = int(len(indices)*split_config[split]))
122+
train_size = int(len(indices)*split_config[split]),
123+
stratify=test_targets)
122124
split_indices.append(train_indices)
125+
# any remaining indices are added to the last split
123126
split_indices.append(test_indices)
124127
return split_indices
125128

unittests/dataset/test_splitter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def dataset_iterator(self) -> DatasetIteratorIF:
2929

3030
@pytest.fixture
3131
def dataset_iterator_stratifiable(self) -> DatasetIteratorIF:
32-
return SequenceDatasetIterator(dataset_sequences=[list(range(10)), [0,0,0,1,1,0,0,1,0,1]])
32+
return SequenceDatasetIterator(dataset_sequences=[list(range(20)), list(np.ones(8, dtype=int))+
33+
list(np.zeros(12, dtype=int))])
3334

3435
def test_random_splitter(self, ratios: List[int], dataset_iterator: DatasetIteratorIF):
3536
splitter_impl = RandomSplitterImpl(ratios=ratios, seed=100)
@@ -51,9 +52,9 @@ def test_stratification(self, split_config: Dict[str, int], dataset_iterator_str
5152
iterator_splits = splitter.split(dataset_iterator_stratifiable)
5253

5354
# target distribution should be equal among all splits
54-
assert(sum([sample[1] for sample in iterator_splits[0]]) == 2)
55-
assert(sum([sample[1] for sample in iterator_splits[1]]) == 1)
56-
assert(sum([sample[1] for sample in iterator_splits[1]]) == 1)
55+
assert(sum([sample[1] for sample in iterator_splits[0]]) == 4)
56+
assert(sum([sample[1] for sample in iterator_splits[1]]) == 2)
57+
assert(sum([sample[1] for sample in iterator_splits[2]]) == 2)
5758

5859
def test_seeding(self):
5960
ratios = [0.4, 0.6]

0 commit comments

Comments
 (0)