Skip to content

Iterators #254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion cascade/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
limitations under the License.
"""

from typing import Union

from .apply_modifier import ApplyModifier
from .bruteforce_cacher import BruteforceCacher
from .composer import Composer
from .concatenator import Concatenator
from .cyclic_sampler import CyclicSampler
from .dataset import (BaseDataset, Dataset, IteratorDataset, IteratorWrapper,
SizedDataset, Wrapper)
SizedDataset, T, Wrapper)
from .filter import Filter, IteratorFilter
from .folder_dataset import FolderDataset
from .functions import dataset, modifier
Expand Down
22 changes: 14 additions & 8 deletions cascade/data/apply_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@
limitations under the License.
"""

from typing import Any, Callable
from typing import Any, Callable, Iterator

from .dataset import Dataset, T
from .modifier import Modifier
from .utils import DatasetOrIterator


class ApplyModifier(Modifier[T]):
"""
Modifier that applies a function to given dataset's items in each __getitem__ call
Modifier that applies a function to given dataset's items in each __getitem__ call.

Can be applied to Iterators too.
"""

def __init__(
self, dataset: Dataset[T], func: Callable[[T], Any], *args: Any, **kwargs: Any
self, dataset: DatasetOrIterator[T], func: Callable[[T], Any], *args: Any, **kwargs: Any
) -> None:
"""
Parameters
Expand All @@ -51,9 +54,12 @@ def __init__(
self._func = func

def __getitem__(self, index: int) -> Any:
item = self._dataset[index]
return self._func(item)
if isinstance(self._dataset, Dataset):
item = self._dataset[index]
return self._func(item)
else:
raise TypeError(f"The underlying dataset is not a Dataset, but {type(self._dataset)}")

# def __iter__(self):
# for item in self._dataset:
# yield self._func(item)
def __iter__(self) -> Iterator[T]:
for item in self._dataset:
yield self._func(item)
13 changes: 9 additions & 4 deletions cascade/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class IteratorDataset(BaseDataset[T], Iterable[T]):
An abstract class to represent a dataset as
an iterable object
"""

def __iter__(self) -> Iterator[T]:
return super().__iter__()

Expand All @@ -75,27 +76,31 @@ class Dataset(BaseDataset[T], Sized):
"""

@abstractmethod
def __getitem__(self, index: Any): ...
def __getitem__(self, index: Any) -> T: ...

@abstractmethod
def __len__(self) -> int: ...

def __iter__(self) -> Iterator[T]:
for i in range(len(self)):
yield self.__getitem__(i)

def get_meta(self) -> Meta:
meta = super().get_meta()
meta[0]["len"] = len(self)
return meta


class IteratorWrapper(BaseDataset[T]):
class IteratorWrapper(IteratorDataset[T]):
"""
Wraps BaseDataset around any Iterable. Does not have map-like interface.
Wraps IteratorDataset around any Iterable. Does not have map-like interface.
"""

def __init__(self, data: Iterable[T], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._data = data

def __iter__(self) -> Generator[T, Any, None]:
def __iter__(self) -> Iterator[T]:
for item in self._data:
yield item

Expand Down
4 changes: 2 additions & 2 deletions cascade/data/simple_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Any, Generator, List, Sequence
from typing import Any, Iterator, List, Sequence

import numpy as np

Expand Down Expand Up @@ -52,7 +52,7 @@ def __getitem__(self, index: int) -> List[Any]:
batch.append(item)
return batch

def __iter__(self) -> Generator[T, Any, None]:
def __iter__(self) -> Iterator[T]:
for i in range(len(self)):
yield self[i]

Expand Down
6 changes: 4 additions & 2 deletions cascade/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
"""

from math import floor
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from .dataset import Dataset, T
from .dataset import Dataset, IteratorDataset, T
from .range_sampler import RangeSampler

DatasetOrIterator = Union[Dataset[T], IteratorDataset[T]]


def split(
ds: Dataset[T], frac: Optional[float] = 0.5, num: Optional[int] = None
Expand Down
31 changes: 21 additions & 10 deletions cascade/models/model_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import itertools
import os
import shutil
from typing import (Any, Dict, Generator, Iterable, List, Literal, Optional,
Type, Union)
from typing import Any, Dict, Iterable, Iterator, List, Literal, Type, Union

from typing_extensions import deprecated

from ..base import Meta, Traceable, TraceableOnDisk
from ..data import T
from ..version import __version__
from .model import Model
from .model_line import ModelLine
Expand Down Expand Up @@ -55,7 +55,7 @@ def __len__(self) -> int:
"""
return len(self._lines)

def __iter__(self) -> Generator[ModelLine, None, None]:
def __iter__(self) -> Iterator[T]:
for line in self._lines:
yield self.__getitem__(line)

Expand All @@ -81,7 +81,9 @@ def reload(self) -> None:
pass


@deprecated("cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead")
@deprecated(
"cascade.models.SingleLineRepo is deprecated, consider using cascade.repos.SingleLineRepo instead"
)
class SingleLineRepo(Repo):
def __init__(
self,
Expand All @@ -99,7 +101,9 @@ def __getitem__(self, key: str) -> ModelLine:
if key in self._lines:
return self._line
else:
raise KeyError(f"The only line is {list(self._lines.keys())[0]}, {key} does not exist")
raise KeyError(
f"The only line is {list(self._lines.keys())[0]}, {key} does not exist"
)

def __repr__(self) -> str:
return f"SingleLine in {self._root}"
Expand All @@ -111,7 +115,9 @@ def reload(self) -> None:
self._line.reload()


@deprecated("cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead")
@deprecated(
"cascade.models.ModelRepo is deprecated, consider using cascade.repos.Repo instead"
)
class ModelRepo(Repo, TraceableOnDisk):
"""
An interface to manage experiments with several lines of models.
Expand Down Expand Up @@ -306,11 +312,16 @@ def load_model_meta(self, model: str) -> Meta:
continue
else:
return meta
raise FileNotFoundError(f"Failed to find the model {model} in the repo at {self._root}")
raise FileNotFoundError(
f"Failed to find the model {model} in the repo at {self._root}"
)

def _update_lines(self) -> None:
for name in sorted(os.listdir(self._root)):
if os.path.isdir(os.path.join(self._root, name)) and name not in self._lines:
if (
os.path.isdir(os.path.join(self._root, name))
and name not in self._lines
):
self._lines[name] = {"args": [], "kwargs": dict()}


Expand All @@ -319,7 +330,7 @@ def _update_lines(self) -> None:
" 0.14.0 and will be removed by 0.15.0"
" Use Workspaces instead",
category=DeprecationWarning,
stacklevel=1
stacklevel=1,
)
class ModelRepoConcatenator(Repo):
"""
Expand Down Expand Up @@ -350,7 +361,7 @@ def __getitem__(self, key) -> ModelLine:
def __len__(self) -> int:
return sum([len(repo) for repo in self._repos])

def __iter__(self) -> Generator[ModelLine, None, None]:
def __iter__(self) -> Iterator[T]:
# this flattens the list of lines
for line in itertools.chain(*[[line for line in repo] for repo in self._repos]):
yield line
Expand Down
5 changes: 3 additions & 2 deletions cascade/models/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

import os
import warnings
from typing import Any, Generator, List, Literal, Optional
from typing import Any, Iterator, List, Literal, Optional

from typing_extensions import deprecated

from ..base import Meta, MetaHandler, MetaIOError, TraceableOnDisk
from ..data import T
from ..models import ModelRepo


Expand Down Expand Up @@ -67,7 +68,7 @@ def __getitem__(self, key: str) -> ModelRepo:
def __len__(self) -> int:
return len(self._repo_names)

def __iter__(self) -> Generator[ModelRepo, None, None]:
def __iter__(self) -> Iterator[T]:
for repo in self._repo_names:
yield self.__getitem__(repo)

Expand Down
25 changes: 15 additions & 10 deletions cascade/tests/data/test_apply_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,28 @@

import pytest

from cascade.data import ApplyModifier
from cascade.data import ApplyModifier, IteratorWrapper, Wrapper

SCRIPT_DIR = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(SCRIPT_DIR))

from cascade.data import Wrapper

data = [
([1, 2, 3, 4, 5], lambda x: x * 2),
([1], lambda x: x**2),
([1, 2, -3], lambda x: x),
]

@pytest.mark.parametrize(
"arr, func",
[
([1, 2, 3, 4, 5], lambda x: x * 2),
([1], lambda x: x**2),
([1, 2, -3], lambda x: x),
],
)

@pytest.mark.parametrize("arr, func", data)
def test_apply_modifier(arr, func):
ds = Wrapper(arr)
ds = ApplyModifier(ds, func)
assert list(map(func, arr)) == [item for item in ds]


@pytest.mark.parametrize("arr, func", data)
def test_apply_modifier_iterators(arr, func):
ds = IteratorWrapper(arr)
ds = ApplyModifier(ds, func)
assert list(map(func, arr)) == [item for item in ds]
6 changes: 6 additions & 0 deletions cascade/tests/data/test_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@


def test_meta(dataset):
meta = dataset.get_meta()
assert isinstance(meta, list)
assert len(meta) > 0
5 changes: 3 additions & 2 deletions cascade/workspaces/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@

import os
import warnings
from typing import Any, Generator, List, Literal, Optional
from typing import Any, Iterator, List, Literal, Optional

from ..base import Meta, MetaHandler, MetaIOError, TraceableOnDisk
from ..data import T
from ..repos.repo import Repo


Expand Down Expand Up @@ -63,7 +64,7 @@ def __getitem__(self, key: str) -> Repo:
def __len__(self) -> int:
return len(self._repo_names)

def __iter__(self) -> Generator[Repo, None, None]:
def __iter__(self) -> Iterator[T]:
for repo in self._repo_names:
yield self.__getitem__(repo)

Expand Down