Skip to content

Schema datasets #242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 12, 2024
Merged
10 changes: 10 additions & 0 deletions cascade/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,21 @@
from .composer import Composer
from .concatenator import Concatenator
from .cyclic_sampler import CyclicSampler
from .dataset import (
BaseDataset,
Dataset,
IteratorDataset,
IteratorWrapper,
SizedDataset,
Wrapper,
)
from .folder_dataset import FolderDataset
from .functions import dataset, modifier
from .modifier import BaseModifier, IteratorModifier, Modifier, Sampler
from .pickler import Pickler
from .random_sampler import RandomSampler
from .range_sampler import RangeSampler
from .schema import SchemaModifier
from .sequential_cacher import SequentialCacher
from .simple_dataloader import SimpleDataloader
from .utils import split
Expand Down
26 changes: 8 additions & 18 deletions cascade/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,10 @@
"""

import warnings
from typing import (
Any,
Generator,
Generic,
Iterable,
Sequence,
Sized,
TypeVar,
)

from ..base import PipeMeta, Traceable, raise_not_implemented
from abc import abstractmethod
from typing import Any, Generator, Generic, Iterable, Sequence, Sized, TypeVar

from ..base import PipeMeta, Traceable

T = TypeVar("T", covariant=True)

Expand All @@ -39,11 +32,8 @@ class BaseDataset(Generic[T], Traceable):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

def __getitem__(self, index: Any) -> T:
"""
Abstract method - should be defined in every successor
"""
raise_not_implemented("cascade.data.Dataset", "__getitem__")
@abstractmethod
def __getitem__(self, index: Any) -> T:...

def get_meta(self) -> PipeMeta:
"""
Expand Down Expand Up @@ -80,8 +70,8 @@ class Dataset(BaseDataset[T], Sized):
cascade.data.Iterator
"""

def __len__(self) -> int:
raise_not_implemented("cascade.data.Dataset", "__len__")
@abstractmethod
def __len__(self) -> int: ...

def get_meta(self) -> PipeMeta:
meta = super().get_meta()
Expand Down
104 changes: 104 additions & 0 deletions cascade/data/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
Copyright 2022-2024 Ilia Moiseev

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Optional

from .dataset import Dataset
from .modifier import Modifier
from .validation import SchemaValidator


class SchemaModifier(Modifier):
    """
    Data validation modifier

    Whenever `self._dataset` is accessed and `in_schema` is defined,
    the underlying dataset is returned wrapped into `ValidationWrapper`,
    which is another `Modifier` that checks the output of
    `__getitem__` of the dataset that was wrapped.

    In the end it will look like this:
    If `in_schema` is not None:
        `dataset = SchemaModifier(ValidationWrapper(dataset))`
    If `in_schema` is None:
        `dataset = SchemaModifier(dataset)`

    How to use it:
    1. Define pydantic schema of input

    ```python
    from typing import List, Tuple

    import pydantic

    class AnnotImage(pydantic.BaseModel):
        image: List[List[List[float]]]
        segments: List[List[int]]
        bboxes: List[Tuple[int, int, int, int]]
    ```

    2. Use schema as `in_schema`

    ```python
    from cascade.data import SchemaModifier

    class ImageModifier(SchemaModifier):
        in_schema = AnnotImage
    ```

    3. Create a regular `Modifier` by
    subclassing ImageModifier.

    ```python
    class IDoNothing(ImageModifier):
        def __getitem__(self, idx):
            item = self._dataset[idx]
            return item
    ```

    4. That's all. Schema check will be held
    automatically every time `self._dataset[idx]` is
    accessed. If the item does not match `AnnotImage`,
    `ValidationError` (defined in `cascade.data.validation`)
    will be raised.

    """

    # Schema that items of the wrapped dataset must satisfy.
    # None (the default) turns validation off.
    in_schema: Optional[Any] = None

    def __getattribute__(self, __name: str) -> Any:
        """
        Intercepts access to `_dataset` and `get_meta`

        Parameters
        ----------
        __name: str
            Name of the attribute being accessed

        Returns
        -------
        Any
            - for `_dataset` with a schema set: a fresh `ValidationWrapper`
              around the real dataset
            - for `get_meta`: a zero-argument callable that returns parent's
              meta with the JSON schema added under `in_schema` in the first
              meta block
            - otherwise: the attribute as resolved by the parent class
        """
        if __name == "_dataset" and self.in_schema is not None:
            # A new wrapper is built on every access, nothing is cached here
            return ValidationWrapper(super().__getattribute__(__name), self.in_schema)

        if __name == "get_meta":
            # NOTE(review): the base implementation is resolved through super(),
            # so the lookup starts after SchemaModifier in the MRO - a `get_meta`
            # defined in a subclass is not reached by attribute access. Confirm
            # this is intended.
            parent_get_meta = super().get_meta

            def get_meta() -> Any:
                meta = parent_get_meta()
                # Same `is not None` test as used for wrapping `_dataset`, so
                # validation and meta reporting never disagree
                if self.in_schema is not None:
                    meta[0]["in_schema"] = self.in_schema.model_json_schema()
                return meta

            return get_meta

        return super().__getattribute__(__name)


class ValidationWrapper(Modifier):
    """
    Modifier that runs a schema validator on every item
    obtained through `__getitem__` before handing it to the caller
    """

    def __init__(
        self,
        dataset: Dataset,
        schema: Any,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Parameters
        ----------
        dataset: Dataset
            Dataset whose items are validated
        schema: Any
            Schema passed to `SchemaValidator`
        *args, **kwargs
            Forwarded to the parent constructor after `dataset`
        """
        # The validator is created before the parent constructor runs
        self.validator = SchemaValidator(schema)
        super().__init__(dataset, *args, **kwargs)

    def __getitem__(self, index: Any):
        """
        Returns the item at `index` from the parent's `__getitem__`
        after passing it through `self.validator`
        """
        fetched = super().__getitem__(index)
        self.validator(fetched)
        return fetched
108 changes: 84 additions & 24 deletions cascade/data/validation.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,128 @@
"""
Copyright 2022-2024 Ilia Moiseev

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import inspect
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Dict, Literal, Tuple
from typing import Any, Callable, Dict, Literal, Tuple, Union

SupportedProviders = Literal["pydantic"]

TypeDict = Dict[str, Tuple[Any, Any]]


class ValidationError(Exception):
    """
    Raised by validators in this module when a value
    does not satisfy the schema it is checked against
    """


class ValidationProvider:
def __init__(self, types: Dict[str, Tuple[Any, Any]]) -> None:
self._types = types
def __init__(self, schema: Any) -> None:
self._schema = schema

def __call__(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError()


class PydanticValidator(ValidationProvider):
def __init__(self, types: Dict[str, Tuple[Any, Any]]) -> None:
super().__init__(types)
def __init__(self, schema: Any) -> None:
super().__init__(schema)

try:
from pydantic import ValidationError, create_model
from pydantic import BaseModel, ValidationError
except ImportError as e:
raise ImportError(
"Cannot import `pydantic` - it is optional dependency for general type checking"
) from e
else:
self._base_model_cls = BaseModel
self._exc_type = ValidationError
self._model = create_model("pydantic_validator", **types) # type: ignore

def __call__(self, *args: Any, **kwargs: Any) -> None:
from_args = dict()
for name, arg in zip(self._types, args):
from_args[name] = arg
if (
len(args) == 1
and len(kwargs) == 0
and (isinstance(args[0], dict) or isinstance(args[0], self._base_model_cls))
):
try:
self._schema.model_validate(args[0])
except self._exc_type as e:
raise ValidationError() from e
else:
from_args = dict()
for name, arg in zip(self._schema.model_fields, args):
from_args[name] = arg

try:
self._model(**from_args, **kwargs)
except self._exc_type as e:
raise ValidationError() from e
try:
self._schema(**from_args, **kwargs)
except self._exc_type as e:
raise ValidationError() from e


class Validator:
def __init__(self, types: Dict[str, Tuple[Any, Any]]) -> None:
providers = {"pydantic": PydanticValidator}
providers = {"pydantic": PydanticValidator}
_validators = []

def __call__(self, *args: Any, **kwargs: Any) -> Any:
for validator in self._validators:
validator(*args, **kwargs)


class SchemaValidator(Validator):
    """
    Validator that checks values against a single ready-made schema
    """

    def __init__(self, schema: Any) -> None:
        """
        Parameters
        ----------
        schema: Any
            Schema object handed to the provider chosen by
            `_resolve_validator`
        """
        name = self._resolve_validator(schema)
        # `_validators` is declared as a list on the `Validator` class itself.
        # Appending to it would mutate that one shared list, so every
        # validator instance would accumulate and run the schemas of all the
        # others. Binding a fresh list on the instance keeps them separate.
        self._validators = [self.providers[name](schema)]

    def _resolve_validator(self, *args: Any, **kwargs: Any) -> SupportedProviders:
        """
        Returns the name of the provider to use. Arguments are
        ignored - "pydantic" is always returned.
        """
        return "pydantic"


class SchemaFactory:
    """
    Builds provider-specific schema objects from dictionaries of types
    """

    @classmethod
    def build(cls, types: TypeDict, provider: SupportedProviders) -> Any:
        """
        Creates a schema for the given provider

        Parameters
        ----------
        types: TypeDict
            Mapping from field name to a two-element tuple; it is passed
            to pydantic's `create_model` as field definitions
        provider: SupportedProviders
            Name of the validation provider, only "pydantic" is supported

        Returns
        -------
        Any
            Model class created by `pydantic.create_model` with
            `arbitrary_types_allowed` enabled

        Raises
        ------
        ValueError
            If `provider` is not supported
        ImportError
            If `pydantic` is not installed
        """
        # Fail loudly instead of silently returning None for an unknown provider
        if provider != "pydantic":
            raise ValueError(f"Unsupported validation provider: {provider!r}")

        try:
            # Imported lazily since pydantic is an optional dependency
            from pydantic import create_model
        except ImportError as e:
            raise ImportError(
                "Cannot import `pydantic` - it is optional dependency"
                " for general type checking"
                "\nYou can install it with `pip install 'pydantic==2.6.4'`"
            ) from e

        return create_model(
            "pydantic_validator", __config__=dict(arbitrary_types_allowed=True), **types
        )  # type: ignore


class TypesValidator(Validator):
def __init__(self, types: TypeDict) -> None:
super().__init__()
provider_to_args = defaultdict(dict)
for name in types:
provider = self._resolve_validator(types[name][0])
provider_to_args[provider][name] = types[name]

self._validators = []
for name in provider_to_args:
validator = providers[name](provider_to_args[name])
for provider in provider_to_args:
schema = SchemaFactory.build(provider_to_args[provider], provider)
validator = self.providers[provider](schema)
self._validators.append(validator)

def _resolve_validator(self, type: Any) -> SupportedProviders:
return "pydantic"

def __call__(self, *args: Any, **kwargs: Any) -> Any:
for validator in self._validators:
validator(*args, **kwargs)


def validate_in(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Expand Down Expand Up @@ -92,7 +152,7 @@ def validate_in(f: Callable[..., Any]) -> Callable[..., Any]:
)
for key in sig.parameters
}
v = Validator(args)
v = TypesValidator(args)

@wraps(f)
def wrapper(*args: Any, **kwargs: Any):
Expand Down
3 changes: 3 additions & 0 deletions cascade/meta/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Callable, Dict, List, NoReturn, Union

from tqdm import tqdm
from typing_extensions import deprecated

from ..data.dataset import BaseDataset, T
from ..data.modifier import Modifier
Expand All @@ -37,6 +38,8 @@ class Validator(Modifier[T]):
Base class for validators. Defines basic `__init__` structure
"""

@deprecated("Whole cascade.meta.validation is deprecated since 0.14.0 and is planned to"
" be removed in 0.15.0. Use cascade.data.SchemaDataset instead")
def __init__(
self,
dataset: BaseDataset[T],
Expand Down
2 changes: 1 addition & 1 deletion cascade/tests/data/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2022-2023 Ilia Moiseev
Copyright 2022-2024 Ilia Moiseev

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
6 changes: 3 additions & 3 deletions cascade/tests/data/test_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MODULE_PATH = os.path.dirname(os.path.abspath(os.path.dirname(__file__)))
sys.path.append(os.path.dirname(MODULE_PATH))

from cascade.data import Modifier, Wrapper, IteratorWrapper, ItModifier
from cascade.data import IteratorModifier, IteratorWrapper, Modifier, Wrapper


def test_iter_of_modifier():
Expand All @@ -39,9 +39,9 @@ def test_iter_of_modifier():
assert [1, 2, 3, 4, 5] == result2


def test_iter_of_itmodifier():
def test_iter_of_IteratorModifier():
d = IteratorWrapper([1, 2, 3, 4, 5])
m = ItModifier(d)
m = IteratorModifier(d)

result1 = []
for item in d:
Expand Down
Loading