Skip to content

Commit 2890fe3

Browse files
authored
Merge pull request #181 from Oxid15/utils_restructure
Utils restructure
2 parents 1a329df + a07aae0 commit 2890fe3

25 files changed

+306
-184
lines changed

cascade/utils/baselines/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .constant_baseline import ConstantBaseline

cascade/utils/baselines.py renamed to cascade/utils/baselines/constant_baseline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121

22-
from ..models import BasicModel
22+
from ...models import BasicModel
2323

2424
Number = Union[int, float, complex, np.number]
2525

cascade/utils/nlp/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .text_classification_folder import TextClassificationFolder

cascade/utils/text_classification_dataset.py renamed to cascade/utils/nlp/text_classification_folder.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919

2020
import numpy as np
2121

22-
from ..base import PipeMeta
23-
from ..data import Dataset
22+
from ...base import PipeMeta
23+
from ...data import Dataset
2424

2525

26-
class TextClassificationDataset(Dataset):
26+
class TextClassificationFolder(Dataset):
2727
"""
2828
Dataset to simplify loading of data for text classification.
2929
Texts of different classes should be placed in different folders.
3030
"""
3131

32+
# TODO: could this be generalized into a ClassificationFolder so that the functionality is shared with image datasets?
3233
def __init__(
3334
self, path: str, encoding: str = "utf-8", *args: Any, **kwargs: Any
3435
) -> None:

cascade/utils/numpy_wrapper.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,4 @@ class NumpyWrapper(Wrapper):
2828
"""
2929

3030
def __init__(self, path: str, *args: Any, **kwargs: Any) -> None:
31-
self._path = path
32-
super().__init__(np.load(path), *args, **kwargs)
33-
34-
def get_meta(self) -> PipeMeta:
35-
meta = super().get_meta()
36-
meta[0]["root"] = self._path
37-
return meta
31+
raise ImportError("NumpyWrapper was removed since 0.12.0, consider using older version")

cascade/utils/pandera/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .pa_schema_validator import PaSchemaValidator

cascade/utils/pa_schema_validator.py renamed to cascade/utils/pandera/pa_schema_validator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
import pandera.io as paio
2020
from pandera.errors import SchemaError
2121

22-
from ..meta import AggregateValidator, DataValidationException
23-
from .tables import TableDataset
22+
from ...meta import AggregateValidator, DataValidationException
23+
from ..tables import TableDataset
2424

2525

2626
class PaSchemaValidator(AggregateValidator):

cascade/utils/sklearn/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .sk_model import SkModel

cascade/utils/sk_model.py renamed to cascade/utils/sklearn/sk_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323
from sklearn.pipeline import Pipeline
2424

25-
from ..base import MetaHandler, PipeMeta
26-
from ..models import BasicModel
25+
from ...base import MetaHandler, PipeMeta
26+
from ...models import BasicModel
2727

2828

2929
class SkModel(BasicModel):

cascade/utils/tables/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .tables import TableDataset, TableFilter, TableIterator, PartedTableLoader, LargeCSVDataset, FeatureTable

cascade/utils/tables.py renamed to cascade/utils/tables/tables.py

+14-54
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,11 @@
1717
from typing import Any, Callable, List, Literal, Tuple, Union
1818

1919
import pandas as pd
20-
from dask import dataframe as dd
2120
from tqdm import tqdm
2221

23-
from cascade.base import PipeMeta
24-
25-
from ..base import PipeMeta
26-
from ..data import Dataset, Iterator, Modifier, SequentialCacher
27-
from ..meta import AggregateValidator, DataValidationException
22+
from ...base import PipeMeta
23+
from ...data import Dataset, Iterator, Modifier
24+
from ...meta import AggregateValidator, DataValidationException
2825

2926

3027
class TableDataset(Dataset):
@@ -132,33 +129,6 @@ def __init__(self, csv_file_path: str, *args: Any, **kwargs: Any) -> None:
132129
super().__init__(t=t, **kwargs)
133130

134131

135-
class PartedTableLoader(Dataset):
136-
"""
137-
Works like CSVDataset, but uses dask to load tables
138-
and returns partitions on `__getitem__`.
139-
140-
See also
141-
--------
142-
cascade.utils.CSVDataset
143-
"""
144-
145-
def __init__(self, csv_file_path: str, *args: Any, **kwargs: Any) -> None:
146-
super().__init__(**kwargs)
147-
self._table = dd.read_csv(csv_file_path, *args, **kwargs)
148-
149-
def __getitem__(self, index: int):
150-
"""
151-
Returns partition under the index.
152-
"""
153-
return self._table.get_partition(index).compute()
154-
155-
def __len__(self) -> int:
156-
"""
157-
Returns the number of partitions.
158-
"""
159-
return self._table.npartitions
160-
161-
162132
class TableIterator(Iterator):
163133
"""
164134
Iterates over the table from path by the chunks.
@@ -182,26 +152,6 @@ def __next__(self):
182152
return self._data.get_chunk(self.chunk_size)
183153

184154

185-
class LargeCSVDataset(SequentialCacher):
186-
"""
187-
SequentialCacher over large .csv file.
188-
Loads table by partitions.
189-
"""
190-
191-
def __init__(self, csv_file_path: str, *args: Any, **kwargs: Any) -> None:
192-
dataset = PartedTableLoader(csv_file_path, *args, **kwargs)
193-
self._ln = len(dataset._table)
194-
self.num_batches = dataset._table.npartitions
195-
self.bs = self._ln // self.num_batches
196-
super().__init__(dataset, self.bs)
197-
198-
def _load(self, index: int) -> None:
199-
self._batch = TableDataset(t=self._dataset[index])
200-
201-
def __len__(self) -> int:
202-
return self._ln
203-
204-
205155
class NullValidator(TableDataset, AggregateValidator):
206156
"""
207157
Checks that there are no null values in the table.
@@ -240,7 +190,7 @@ def __init__(
240190
```python
241191
>>> import pandas as pd
242192
>>> from cascade.utils.tables import FeatureTable
243-
>>> df = pd.read_csv(r'C:\cascade_integration\data\t.csv', index_col=0)
193+
>>> df = pd.read_csv(r'data\t.csv', index_col=0)
244194
>>> df
245195
id count name
246196
0 0 1 aaa
@@ -370,3 +320,13 @@ def get_meta(self) -> PipeMeta:
370320
for key in self._computed_features_kwargs
371321
}
372322
return meta
323+
324+
325+
class PartedTableLoader(TableDataset):
326+
def __init__(self, *args: Any, t = None, **kwargs: Any) -> None:
327+
raise ImportError("PartedTableLoader was removed since 0.12.0, consider using older version")
328+
329+
330+
class LargeCSVDataset(TableDataset):
331+
def __init__(self, *args: Any, t = None, **kwargs: Any) -> None:
332+
raise ImportError("LargeCSVDataset was removed since 0.12.0, consider using older version")

cascade/utils/tests/test_baselines.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ def test():
3939
model.predict([0, 0, 0])
4040
== np.array([[[1, 0], [0, 1]], [[1, 0], [0, 1]], [[1, 0], [0, 1]]])
4141
)
42+

cascade/utils/tests/test_folder_image_dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
sys.path.append(os.path.dirname(MODULE_PATH))
2828

29-
from cascade.utils.folder_image_dataset import FolderImageDataset
29+
from cascade.utils.vision import FolderImageDataset
3030

3131

3232
@pytest.fixture

cascade/utils/tests/test_numpy_wrapper.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import sys
1919

20+
import pytest
2021
import numpy as np
2122

2223
MODULE_PATH = os.path.dirname(
@@ -27,6 +28,7 @@
2728
from cascade.utils.numpy_wrapper import NumpyWrapper
2829

2930

31+
@pytest.mark.skip
3032
def test(tmp_path):
3133
arr = np.array([1, 2, 3, 4, 5])
3234
path = os.path.join(tmp_path, "arr.npy")

cascade/utils/tests/test_sk_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
import cascade as csd
32-
from cascade.utils.sk_model import SkModel
32+
from cascade.utils.sklearn import SkModel
3333

3434

3535
@pytest.mark.parametrize("ext", [".json", ".yml"])

cascade/utils/tests/test_text_classification_dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
sys.path.append(os.path.dirname(MODULE_PATH))
2424

25-
from cascade.utils.text_classification_dataset import TextClassificationDataset
25+
from cascade.utils.nlp import TextClassificationFolder
2626

2727

2828
def test_create(tmp_path):
@@ -38,7 +38,7 @@ def test_create(tmp_path):
3838
with open(os.path.join(path, "text_2.txt"), "w") as f:
3939
f.write("hello")
4040

41-
ds = TextClassificationDataset(tmp_path)
41+
ds = TextClassificationFolder(tmp_path)
4242
meta = ds.get_meta()[0]
4343

4444
assert meta["size"] == 6

cascade/utils/tests/test_torch_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
sys.path.append(os.path.dirname(MODULE_PATH))
2727

2828

29-
from cascade.utils.torch_model import TorchModel
29+
from cascade.utils.torch import TorchModel
3030

3131

3232
@pytest.mark.parametrize("postfix", ["", "model", "model.pt"])

cascade/utils/time_series/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Copyright 2022-2023 Ilia Moiseev
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
from .time_series_dataset import TimeSeriesDataset
18+
from .time_series import Average, Interpolate, Align

0 commit comments

Comments
 (0)