Skip to content

Commit 93039e4

Browse files
authored
Merge pull request #254 from Oxid15/iterators
Iterators
2 parents 8facdcc + 12d96df commit 93039e4

File tree

10 files changed

+80
-41
lines changed

10 files changed

+80
-41
lines changed

cascade/data/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
limitations under the License.
1515
"""
1616

17+
from typing import Union
18+
1719
from .apply_modifier import ApplyModifier
1820
from .bruteforce_cacher import BruteforceCacher
1921
from .composer import Composer
2022
from .concatenator import Concatenator
2123
from .cyclic_sampler import CyclicSampler
2224
from .dataset import (BaseDataset, Dataset, IteratorDataset, IteratorWrapper,
23-
SizedDataset, Wrapper)
25+
SizedDataset, T, Wrapper)
2426
from .filter import Filter, IteratorFilter
2527
from .folder_dataset import FolderDataset
2628
from .functions import dataset, modifier

cascade/data/apply_modifier.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,22 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Any, Callable
17+
from typing import Any, Callable, Iterator
1818

1919
from .dataset import Dataset, T
2020
from .modifier import Modifier
21+
from .utils import DatasetOrIterator
2122

2223

2324
class ApplyModifier(Modifier[T]):
2425
"""
25-
Modifier that applies a function to given dataset's items in each __getitem__ call
26+
Modifier that applies a function to the given dataset's items on each __getitem__ call.
27+
28+
Can be applied to an IteratorDataset too; in that case only iteration is supported and __getitem__ raises TypeError.
2629
"""
2730

2831
def __init__(
29-
self, dataset: Dataset[T], func: Callable[[T], Any], *args: Any, **kwargs: Any
32+
self, dataset: DatasetOrIterator[T], func: Callable[[T], Any], *args: Any, **kwargs: Any
3033
) -> None:
3134
"""
3235
Parameters
@@ -51,9 +54,12 @@ def __init__(
5154
self._func = func
5255

5356
def __getitem__(self, index: int) -> Any:
54-
item = self._dataset[index]
55-
return self._func(item)
57+
if isinstance(self._dataset, Dataset):
58+
item = self._dataset[index]
59+
return self._func(item)
60+
else:
61+
raise TypeError(f"The underlying dataset is not a Dataset, but {type(self._dataset)}")
5662

57-
# def __iter__(self):
58-
# for item in self._dataset:
59-
# yield self._func(item)
63+
def __iter__(self) -> Iterator[T]:
64+
for item in self._dataset:
65+
yield self._func(item)

cascade/data/dataset.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class IteratorDataset(BaseDataset[T], Iterable[T]):
5656
An abstract class to represent a dataset as
5757
an iterable object
5858
"""
59+
5960
def __iter__(self) -> Iterator[T]:
6061
return super().__iter__()
6162

@@ -75,27 +76,31 @@ class Dataset(BaseDataset[T], Sized):
7576
"""
7677

7778
@abstractmethod
78-
def __getitem__(self, index: Any): ...
79+
def __getitem__(self, index: Any) -> T: ...
7980

8081
@abstractmethod
8182
def __len__(self) -> int: ...
8283

84+
def __iter__(self) -> Iterator[T]:
85+
for i in range(len(self)):
86+
yield self.__getitem__(i)
87+
8388
def get_meta(self) -> Meta:
8489
meta = super().get_meta()
8590
meta[0]["len"] = len(self)
8691
return meta
8792

8893

89-
class IteratorWrapper(BaseDataset[T]):
94+
class IteratorWrapper(IteratorDataset[T]):
9095
"""
91-
Wraps BaseDataset around any Iterable. Does not have map-like interface.
96+
Wraps an IteratorDataset around any Iterable. Does not have a map-like interface.
9297
"""
9398

9499
def __init__(self, data: Iterable[T], *args: Any, **kwargs: Any) -> None:
95100
super().__init__(*args, **kwargs)
96101
self._data = data
97102

98-
def __iter__(self) -> Generator[T, Any, None]:
103+
def __iter__(self) -> Iterator[T]:
99104
for item in self._data:
100105
yield item
101106

cascade/data/simple_dataloader.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Any, Generator, List, Sequence
17+
from typing import Any, Iterator, List, Sequence
1818

1919
import numpy as np
2020

@@ -52,7 +52,7 @@ def __getitem__(self, index: int) -> List[Any]:
5252
batch.append(item)
5353
return batch
5454

55-
def __iter__(self) -> Generator[T, Any, None]:
55+
def __iter__(self) -> Iterator[T]:
5656
for i in range(len(self)):
5757
yield self[i]
5858

cascade/data/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
"""
1616

1717
from math import floor
18-
from typing import Optional, Tuple
18+
from typing import Optional, Tuple, Union
1919

20-
from .dataset import Dataset, T
20+
from .dataset import Dataset, IteratorDataset, T
2121
from .range_sampler import RangeSampler
2222

23+
DatasetOrIterator = Union[Dataset[T], IteratorDataset[T]]
24+
2325

2426
def split(
2527
ds: Dataset[T], frac: Optional[float] = 0.5, num: Optional[int] = None

cascade/models/model_repo.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
import itertools
1515
import os
1616
import shutil
17-
from typing import (Any, Dict, Generator, Iterable, List, Literal, Optional,
18-
Type, Union)
17+
from typing import Any, Dict, Iterable, Iterator, List, Literal, Type, Union
1918

2019
from typing_extensions import deprecated
2120

2221
from ..base import Meta, Traceable, TraceableOnDisk
22+
from ..data import T
2323
from ..version import __version__
2424
from .model import Model
2525
from .model_line import ModelLine
@@ -55,7 +55,7 @@ def __len__(self) -> int:
5555
"""
5656
return len(self._lines)
5757

58-
def __iter__(self) -> Generator[ModelLine, None, None]:
58+
def __iter__(self) -> Iterator[T]:
5959
for line in self._lines:
6060
yield self.__getitem__(line)
6161

@@ -81,7 +81,9 @@ def reload(self) -> None:
8181
pass
8282

8383

84-
@deprecated("cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead")
84+
@deprecated(
85+
"cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead"
86+
)
8587
class SingleLineRepo(Repo):
8688
def __init__(
8789
self,
@@ -99,7 +101,9 @@ def __getitem__(self, key: str) -> ModelLine:
99101
if key in self._lines:
100102
return self._line
101103
else:
102-
raise KeyError(f"The only line is {list(self._lines.keys())[0]}, {key} does not exist")
104+
raise KeyError(
105+
f"The only line is {list(self._lines.keys())[0]}, {key} does not exist"
106+
)
103107

104108
def __repr__(self) -> str:
105109
return f"SingleLine in {self._root}"
@@ -111,7 +115,9 @@ def reload(self) -> None:
111115
self._line.reload()
112116

113117

114-
@deprecated("cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead")
118+
@deprecated(
119+
"cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead"
120+
)
115121
class ModelRepo(Repo, TraceableOnDisk):
116122
"""
117123
An interface to manage experiments with several lines of models.
@@ -306,11 +312,16 @@ def load_model_meta(self, model: str) -> Meta:
306312
continue
307313
else:
308314
return meta
309-
raise FileNotFoundError(f"Failed to find the model {model} in the repo at {self._root}")
315+
raise FileNotFoundError(
316+
f"Failed to find the model {model} in the repo at {self._root}"
317+
)
310318

311319
def _update_lines(self) -> None:
312320
for name in sorted(os.listdir(self._root)):
313-
if os.path.isdir(os.path.join(self._root, name)) and name not in self._lines:
321+
if (
322+
os.path.isdir(os.path.join(self._root, name))
323+
and name not in self._lines
324+
):
314325
self._lines[name] = {"args": [], "kwargs": dict()}
315326

316327

@@ -319,7 +330,7 @@ def _update_lines(self) -> None:
319330
" 0.14.0 and will be removed by 0.15.0"
320331
" Use Workspaces instead",
321332
category=DeprecationWarning,
322-
stacklevel=1
333+
stacklevel=1,
323334
)
324335
class ModelRepoConcatenator(Repo):
325336
"""
@@ -350,7 +361,7 @@ def __getitem__(self, key) -> ModelLine:
350361
def __len__(self) -> int:
351362
return sum([len(repo) for repo in self._repos])
352363

353-
def __iter__(self) -> Generator[ModelLine, None, None]:
364+
def __iter__(self) -> Iterator[T]:
354365
# this flattens the list of lines
355366
for line in itertools.chain(*[[line for line in repo] for repo in self._repos]):
356367
yield line

cascade/models/workspace.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
import os
1818
import warnings
19-
from typing import Any, Generator, List, Literal, Optional
19+
from typing import Any, Iterator, List, Literal, Optional
2020

2121
from typing_extensions import deprecated
2222

2323
from ..base import Meta, MetaHandler, MetaIOError, TraceableOnDisk
24+
from ..data import T
2425
from ..models import ModelRepo
2526

2627

@@ -67,7 +68,7 @@ def __getitem__(self, key: str) -> ModelRepo:
6768
def __len__(self) -> int:
6869
return len(self._repo_names)
6970

70-
def __iter__(self) -> Generator[ModelRepo, None, None]:
71+
def __iter__(self) -> Iterator[T]:
7172
for repo in self._repo_names:
7273
yield self.__getitem__(repo)
7374

cascade/tests/data/test_apply_modifier.py

+15-10
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,28 @@
1919

2020
import pytest
2121

22-
from cascade.data import ApplyModifier
22+
from cascade.data import ApplyModifier, IteratorWrapper, Wrapper
2323

2424
SCRIPT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
2525
sys.path.append(os.path.dirname(SCRIPT_DIR))
2626

27-
from cascade.data import Wrapper
2827

28+
data = [
29+
([1, 2, 3, 4, 5], lambda x: x * 2),
30+
([1], lambda x: x**2),
31+
([1, 2, -3], lambda x: x),
32+
]
2933

30-
@pytest.mark.parametrize(
31-
"arr, func",
32-
[
33-
([1, 2, 3, 4, 5], lambda x: x * 2),
34-
([1], lambda x: x**2),
35-
([1, 2, -3], lambda x: x),
36-
],
37-
)
34+
35+
@pytest.mark.parametrize("arr, func", data)
3836
def test_apply_modifier(arr, func):
3937
ds = Wrapper(arr)
4038
ds = ApplyModifier(ds, func)
4139
assert list(map(func, arr)) == [item for item in ds]
40+
41+
42+
@pytest.mark.parametrize("arr, func", data)
43+
def test_apply_modifier_iterators(arr, func):
44+
ds = IteratorWrapper(arr)
45+
ds = ApplyModifier(ds, func)
46+
assert list(map(func, arr)) == [item for item in ds]

cascade/tests/data/test_meta.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
3+
def test_meta(dataset):
4+
meta = dataset.get_meta()
5+
assert isinstance(meta, list)
6+
assert len(meta) > 0

cascade/workspaces/workspace.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616

1717
import os
1818
import warnings
19-
from typing import Any, Generator, List, Literal, Optional
19+
from typing import Any, Iterator, List, Literal, Optional
2020

2121
from ..base import Meta, MetaHandler, MetaIOError, TraceableOnDisk
22+
from ..data import T
2223
from ..repos.repo import Repo
2324

2425

@@ -63,7 +64,7 @@ def __getitem__(self, key: str) -> Repo:
6364
def __len__(self) -> int:
6465
return len(self._repo_names)
6566

66-
def __iter__(self) -> Generator[Repo, None, None]:
67+
def __iter__(self) -> Iterator[T]:
6768
for repo in self._repo_names:
6869
yield self.__getitem__(repo)
6970

0 commit comments

Comments
 (0)