
Commit d128b93

Maren Pielka authored and committed
merged with master

2 parents f38464a + cfd3639

6 files changed: +120 −61 lines

src/data_stack/dataset/reporting.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -19,7 +19,6 @@ class DatasetIteratorReport:
     sample_pos: int
     target_pos: int
     tag_pos: int
-    sample_shape: List[int]
     target_dist: Dict[Union[str, int], int]
     iteration_speed: float
     sub_reports: List["DatasetIteratorReport"]
@@ -38,9 +37,9 @@ def generate_report(iterator: InformedDatasetIterator, report_format: ReportForm
         target_dist = {k: v for k, v in sorted(Counter([row[meta.target_pos] for row in iterator]).items())}
         iteration_speed = DatasetIteratorReportGenerator.measure_iteration_speed(iterator)
         # generate report
+
         report = DatasetIteratorReport(meta.identifier, meta.dataset_name, meta.dataset_tag, len(iterator), meta.sample_pos,
-                                       meta.target_pos, meta.tag_pos, list(iterator[0][meta.sample_pos].shape), target_dist,
-                                       iteration_speed, sub_reports)
+                                       meta.target_pos, meta.tag_pos, target_dist, iteration_speed, sub_reports)
         # format report
         if report_format == DatasetIteratorReportGenerator.ReportFormat.JSON:
             return DatasetIteratorReportGenerator._to_json(report)
```
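The dropped `sample_shape` field was the only report entry that touched actual sample data: it was computed as `list(iterator[0][meta.sample_pos].shape)`, which assumes every sample is tensor-like. Removing it lets reports work for scalar samples such as those in the rewritten tests below. A minimal illustrative sketch (not library code) of the failure mode this avoids:

```python
# Minimal sketch: scalar samples, as used in the new synthetic test fixtures,
# have no .shape attribute, so the old sample_shape lookup would raise.
samples = [0, 1, 2]
try:
    sample_shape = list(samples[0].shape)
except AttributeError:
    sample_shape = None  # the old report code would have crashed here instead
print(sample_shape)  # -> None
```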

src/data_stack/dataset/splitter.py

Lines changed: 14 additions & 10 deletions

```diff
@@ -158,17 +158,21 @@ def __init__(self,
 
     def split(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[DatasetIteratorIF], List[List[DatasetIteratorIF]]]:
         # create outer loop folds
-        targets = [sample[self.target_pos] for sample in dataset_iterator]
-        folds_indices = [fold[1] for fold in self.outer_splitter.split(X=np.zeros(len(targets)), y=targets)]
-        outer_folds = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in folds_indices]
+        targets = np.array([sample[self.target_pos] for sample in dataset_iterator])
+        outer_folds_indices = [fold[1] for fold in self.outer_splitter.split(X=np.zeros(len(targets)), y=targets)]
+        outer_fold_iterators = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in outer_folds_indices]
         # create inner loop folds
-        inner_folds_list = []  # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
-        for iterator in outer_folds:
-            targets = [sample[self.target_pos] for sample in iterator]
-            folds_indices = [fold[1] for fold in self.inner_splitter.split(X=np.zeros(len(targets)), y=targets)]
-            inner_folds = [DatasetIteratorView(iterator, fold_indices) for fold_indices in folds_indices]
-            inner_folds_list.append(inner_folds)
-        return outer_folds, inner_folds_list
+        inner_folds_iterators_list = []  # contains [inner folds of outer_fold_1, inner folds of outer_fold_2 ...]
+        for outer_fold_id in range(len(outer_fold_iterators)):
+            # concat the indices of the splits which belong to the train splits
+            train_split_ids = [i for i in range(len(outer_folds_indices)) if i != outer_fold_id]
+            outer_train_fold_indices = np.array([indice for i in train_split_ids for indice in outer_folds_indices[i]])
+            inner_targets = targets[outer_train_fold_indices]
+            inner_folds_indices = [outer_train_fold_indices[inner_fold[1]]
+                                   for inner_fold in self.inner_splitter.split(X=np.zeros(len(inner_targets)), y=inner_targets)]
+            inner_folds = [DatasetIteratorView(dataset_iterator, fold_indices) for fold_indices in inner_folds_indices]
+            inner_folds_iterators_list.append(inner_folds)
+        return outer_fold_iterators, inner_folds_iterators_list
 
     def get_indices(self, dataset_iterator: DatasetIteratorIF) -> Tuple[List[List[int]], List[List[int]]]:
         outer_folds, inner_folds_list = self.split(dataset_iterator)
```
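The substantive fix: inner folds were previously carved out of a single held-out outer fold; now the inner splitter runs on the union of all *other* outer folds (the outer training set), and the resulting positions are mapped back into the root iterator's index space before views are built, which matches the standard nested cross-validation layout. A standalone sketch of that index bookkeeping, assuming sklearn-style splitter objects (illustrative, not the library code):

```python
# Sketch under assumptions: StratifiedKFold stands in for the outer/inner
# splitter implementations held by NestedCVSplitterImpl.
import numpy as np
from sklearn.model_selection import StratifiedKFold

targets = np.array([0, 1] * 50)  # 100 samples, two balanced classes
outer = StratifiedKFold(n_splits=5)
outer_folds_indices = [test for _, test in outer.split(np.zeros(len(targets)), targets)]

inner = StratifiedKFold(n_splits=2)
for outer_fold_id in range(len(outer_folds_indices)):
    # union of all outer folds except the held-out one = outer training set
    train_ids = [i for i in range(len(outer_folds_indices)) if i != outer_fold_id]
    outer_train_indices = np.concatenate([outer_folds_indices[i] for i in train_ids])
    inner_targets = targets[outer_train_indices]
    # inner split positions are relative to the training subset, so map them
    # back into the root dataset's index space
    inner_folds_indices = [outer_train_indices[test]
                           for _, test in inner.split(np.zeros(len(inner_targets)), inner_targets)]
    # no inner fold may leak indices from the held-out outer fold
    held_out = set(outer_folds_indices[outer_fold_id])
    assert all(not (set(fold) & held_out) for fold in inner_folds_indices)
```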

src/setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
 
 setup(
     name='datastack',
-    version='0.0.9',
+    version='0.0.10',
     author='Max Luebbering',
     description="DataStack, a stream based solution for machine learning dataset retrieval and storage",
     long_description=long_description,
```

unittests/dataset/test_reporting.py

Lines changed: 33 additions & 33 deletions

```diff
@@ -1,13 +1,11 @@
 import pytest
-from typing import List
-from data_stack.mnist.factory import MNISTFactory
 from data_stack.io.storage_connectors import StorageConnector, StorageConnectorFactory
 from data_stack.dataset.reporting import DatasetIteratorReportGenerator
 import tempfile
 import shutil
-from data_stack.dataset.iterator import InformedDatasetIterator
-from data_stack.dataset.meta import MetaFactory
 from data_stack.dataset.factory import InformedDatasetFactory
+from data_stack.dataset.meta import DatasetMeta, MetaFactory
+from data_stack.dataset.iterator import DatasetIteratorIF, SequenceDatasetIterator, InformedDatasetIterator
 
 
 class TestReporting:
@@ -22,36 +20,38 @@ def tmp_folder_path(self) -> str:
     def storage_connector(self, tmp_folder_path: str) -> StorageConnector:
         return StorageConnectorFactory.get_file_storage_connector(tmp_folder_path)
 
-    @pytest.fixture(scope="session")
-    def mnist_factory(self, storage_connector) -> List[int]:
-        mnist_factory = MNISTFactory(storage_connector)
-        return mnist_factory
-
-    def test_plain_iterator_reporting(self, mnist_factory):
-        iterator, iterator_meta = mnist_factory.get_dataset_iterator(config={"split": "train"})
-        dataset_meta = MetaFactory.get_dataset_meta(identifier="id x", dataset_name="MNIST",
-                                                    dataset_tag="train", iterator_meta=iterator_meta)
-
-        informed_iterator = InformedDatasetIterator(iterator, dataset_meta)
-        report = DatasetIteratorReportGenerator.generate_report(informed_iterator)
+    # @pytest.fixture(scope="session")
+    # def mnist_factory(self, storage_connector) -> List[int]:
+    #     mnist_factory = MNISTFactory(storage_connector)
+    #     return mnist_factory
+
+    @pytest.fixture
+    def dataset_meta(self) -> DatasetMeta:
+        iterator_meta = MetaFactory.get_iterator_meta(sample_pos=0, target_pos=1, tag_pos=2)
+        return MetaFactory.get_dataset_meta(identifier="identifier_1",
+                                            dataset_name="TEST DATASET",
+                                            dataset_tag="train",
+                                            iterator_meta=iterator_meta)
+
+    @pytest.fixture
+    def dataset_iterator(self) -> DatasetIteratorIF:
+        targets = [j for i in range(10) for j in range(9)] + [10]*1000
+        samples = [0]*len(targets)
+        return SequenceDatasetIterator(dataset_sequences=[samples, targets])
+
+    @pytest.fixture
+    def informed_dataset_iterator(self, dataset_iterator, dataset_meta) -> DatasetIteratorIF:
+        return InformedDatasetFactory.get_dataset_iterator(dataset_iterator, dataset_meta)
+
+    def test_plain_iterator_reporting(self, informed_dataset_iterator):
+        report = DatasetIteratorReportGenerator.generate_report(informed_dataset_iterator)
         print(report)
-        assert report.length == 60000 and not report.sub_reports
-
-    def test_combined_iterator_reporting(self, mnist_factory):
-
-        iterator_train, iterator_train_meta = mnist_factory.get_dataset_iterator(config={"split": "train"})
-        iterator_test, iterator_test_meta = mnist_factory.get_dataset_iterator(config={"split": "test"})
-        meta_train = MetaFactory.get_dataset_meta(identifier="id x", dataset_name="MNIST",
-                                                  dataset_tag="train", iterator_meta=iterator_train_meta)
-        meta_test = MetaFactory.get_dataset_meta(identifier="id x", dataset_name="MNIST",
-                                                 dataset_tag="train", iterator_meta=iterator_test_meta)
-
-        informed_iterator_train = InformedDatasetFactory.get_dataset_iterator(iterator_train, meta_train)
-        informed_iterator_test = InformedDatasetFactory.get_dataset_iterator(iterator_test, meta_test)
-
-        meta_combined = MetaFactory.get_dataset_meta_from_existing(informed_iterator_train.dataset_meta, dataset_tag="full")
+        assert report.length == 1090 and not report.sub_reports
 
-        iterator = InformedDatasetFactory.get_combined_dataset_iterator([informed_iterator_train, informed_iterator_test], meta_combined)
+    def test_combined_iterator_reporting(self, informed_dataset_iterator):
+        meta_combined = MetaFactory.get_dataset_meta_from_existing(informed_dataset_iterator.dataset_meta, dataset_tag="full")
+        iterator = InformedDatasetFactory.get_combined_dataset_iterator(
+            [informed_dataset_iterator, informed_dataset_iterator], meta_combined)
         report = DatasetIteratorReportGenerator.generate_report(iterator)
-        assert report.length == 70000 and report.sub_reports[0].length == 60000 and report.sub_reports[1].length == 10000
+        assert report.length == 2180 and report.sub_reports[0].length == 1090 and report.sub_reports[1].length == 1090
         assert not report.sub_reports[0].sub_reports and not report.sub_reports[1].sub_reports
```
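The new expected lengths follow from the synthetic fixture: classes 0 through 8 appear 10 times each (90 samples) plus 1000 samples of class 10, and the combined test stacks the same 1090-sample iterator twice. A quick check of that arithmetic:

```python
# Lengths asserted by the rewritten tests:
targets = [j for i in range(10) for j in range(9)] + [10] * 1000
assert len(targets) == 10 * 9 + 1000 == 1090  # plain report
assert 2 * len(targets) == 2180               # combined report (same iterator twice)
```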

unittests/dataset/test_splitter.py

Lines changed: 57 additions & 1 deletion

```diff
@@ -1,8 +1,9 @@
 import pytest
 import numpy as np
+import collections
 from data_stack.dataset.iterator import DatasetIteratorIF, SequenceDatasetIterator
 from typing import List, Dict
-from data_stack.dataset.splitter import RandomSplitterImpl, StratifiedSplitterImpl, Splitter
+from data_stack.dataset.splitter import RandomSplitterImpl, StratifiedSplitterImpl, Splitter, NestedCVSplitterImpl
 from data_stack.dataset.meta import DatasetMeta, MetaFactory
 
 
@@ -27,6 +28,12 @@ def dataset_meta(self) -> DatasetMeta:
     def dataset_iterator(self) -> DatasetIteratorIF:
         return SequenceDatasetIterator(dataset_sequences=[list(range(10)), list(range(10))])
 
+    @pytest.fixture
+    def big_dataset_iterator(self) -> DatasetIteratorIF:
+        targets = [j for i in range(10) for j in range(9)] + [10] * 1000
+        samples = [0] * len(targets)
+        return SequenceDatasetIterator(dataset_sequences=[samples, targets])
+
     @pytest.fixture
     def dataset_iterator_stratifiable(self) -> DatasetIteratorIF:
         return SequenceDatasetIterator(dataset_sequences=[list(range(20)), list(np.ones(8, dtype=int))+
@@ -56,6 +63,55 @@ def test_stratification(self, split_config: Dict[str, int], dataset_iterator_str
         assert(sum([sample[1] for sample in iterator_splits[1]]) == 2)
         assert(sum([sample[1] for sample in iterator_splits[2]]) == 2)
 
+    @pytest.mark.parametrize(
+        "num_outer_loop_folds, num_inner_loop_folds, inner_stratification, outer_stratification, shuffle",
+        [(5, 2, True, True, False), (5, 2, True, True, True), (5, 2, False, False, True),
+         (5, 2, False, False, False)],
+    )
+    def test_nested_cv_splitter(self, num_outer_loop_folds: int, num_inner_loop_folds: int,
+                                inner_stratification: bool,
+                                outer_stratification: bool, shuffle: bool, big_dataset_iterator: DatasetIteratorIF):
+        splitter_impl = NestedCVSplitterImpl(num_outer_loop_folds=num_outer_loop_folds,
+                                             num_inner_loop_folds=num_inner_loop_folds,
+                                             inner_stratification=inner_stratification,
+                                             outer_stratification=outer_stratification,
+                                             shuffle=shuffle)
+        splitter = Splitter(splitter_impl)
+        outer_folds, inner_folds = splitter.split(big_dataset_iterator)
+        # make sure that outer folds have no intersection
+        for i in range(len(outer_folds)):
+            for j in range(len(outer_folds)):
+                if i != j:
+                    # makes sure there is no intersection
+                    assert len(set(outer_folds[i].indices).intersection(set(outer_folds[j].indices))) == 0
+        # make sure that inner folds have no intersection
+        for i in range(len(inner_folds)):
+            for j in range(len(inner_folds[i])):
+                for k in range(len(inner_folds[i])):
+                    if j != k:
+                        # makes sure there is no intersection
+                        assert len(set(inner_folds[i][j].indices).intersection(set(inner_folds[i][k].indices))) == 0
+        # test stratification
+        if outer_stratification:
+            class_counts = dict(collections.Counter([t for _, t in big_dataset_iterator]))
+            class_counts_per_fold = {target_class: int(count / num_outer_loop_folds) for target_class, count in
+                                     class_counts.items()}
+            for fold in outer_folds:
+                fold_class_counts = dict(collections.Counter([t for _, t in fold]))
+                for key in list(class_counts_per_fold.keys()) + list(fold_class_counts.keys()):
+                    assert class_counts_per_fold[key] == fold_class_counts[key]
+
+        if inner_stratification:
+            for i in range(len(inner_folds)):
+                class_counts = dict(collections.Counter([t for _, t in outer_folds[i]]))
+                class_counts_per_fold = {
+                    target_class: int(count * (num_outer_loop_folds - 1) / num_inner_loop_folds) for
+                    target_class, count in class_counts.items()}
+                for fold in inner_folds[i]:
+                    fold_class_counts = dict(collections.Counter([t for _, t in fold]))
+                    for key in list(class_counts_per_fold.keys()) + list(fold_class_counts.keys()):
+                        assert class_counts_per_fold[key] == fold_class_counts[key]
+
     def test_seeding(self):
         ratios = [0.4, 0.6]
         dataset_length = 100
```
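The stratification asserts encode simple per-fold arithmetic: each class count is divided evenly across the 5 outer folds, and each inner fold sees the 4 remaining outer folds split 2 ways. For the `big_dataset_iterator` fixture this works out as follows (a worked check mirroring the test's formulas):

```python
import collections

# Fixture targets: classes 0..8 appear 10 times each, class 10 appears 1000 times.
targets = [j for i in range(10) for j in range(9)] + [10] * 1000
counts = collections.Counter(targets)

# Outer folds (5): each fold holds count / 5 samples of every class.
outer_per_fold = {c: n // 5 for c, n in counts.items()}
assert outer_per_fold[0] == 2 and outer_per_fold[10] == 200

# Inner folds (2): built from the 4 remaining outer folds, i.e.
# per-outer-fold count * (5 - 1) / 2, matching the test's formula.
inner_per_fold = {c: (n // 5) * (5 - 1) // 2 for c, n in counts.items()}
assert inner_per_fold[0] == 4 and inner_per_fold[10] == 400
```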

unittests/io/test_retriever.py

Lines changed: 13 additions & 13 deletions

```diff
@@ -61,19 +61,19 @@ def file_retriever(self, storage_connector: StorageConnector):
     def http_retriever_impl(self, storage_connector: StorageConnector):
         return HTTPRetrieverImpl(storage_connector)
 
-    def test_http_retriever_retrieve(self, http_retriever: Retriever, http_retrieval_job: ResourceDefinition):
-        http_retriever.retrieve([http_retrieval_job])
-        storage_connector = http_retriever.retriever_impl.storage_connector
-        resource = storage_connector.get_resource(http_retrieval_job.identifier)
-        assert TestBaseRetriever.get_md5(resource) == http_retrieval_job.md5_sum
-
-    def test_http_retriever_impl_download_file(self, http_retriever_impl: HTTPRetrieverImpl, http_retrieval_job: ResourceDefinition, tmp_folder_path: str):
-        file_path = http_retriever_impl._download_file(url=http_retrieval_job.source,
-                                                       dest_folder=tmp_folder_path,
-                                                       md5=http_retrieval_job.md5_sum)
-        with open(file_path, "rb") as fd:
-            md5_sum = TestBaseRetriever.get_md5(fd)
-        return md5_sum == http_retrieval_job.md5_sum
+    # def test_http_retriever_retrieve(self, http_retriever: Retriever, http_retrieval_job: ResourceDefinition):
+    #     http_retriever.retrieve([http_retrieval_job])
+    #     storage_connector = http_retriever.retriever_impl.storage_connector
+    #     resource = storage_connector.get_resource(http_retrieval_job.identifier)
+    #     assert TestBaseRetriever.get_md5(resource) == http_retrieval_job.md5_sum
+
+    # def test_http_retriever_impl_download_file(self, http_retriever_impl: HTTPRetrieverImpl, http_retrieval_job: ResourceDefinition, tmp_folder_path: str):
+    #     file_path = http_retriever_impl._download_file(url=http_retrieval_job.source,
+    #                                                    dest_folder=tmp_folder_path,
+    #                                                    md5=http_retrieval_job.md5_sum)
+    #     with open(file_path, "rb") as fd:
+    #         md5_sum = TestBaseRetriever.get_md5(fd)
+    #     return md5_sum == http_retrieval_job.md5_sum
 
     def test_file_retriever_retrieve(self, file_retriever: Retriever, file_retrieval_job: ResourceDefinition):
         file_retriever.retrieve([file_retrieval_job])
```
