Skip to content

Commit ff0cb4f

Browse files
MatthewMiddlehurst authored and chrisholder committed
base transform tidy (#2773)
1 parent 8a151e4 commit ff0cb4f

File tree

3 files changed

+89
-82
lines changed

3 files changed

+89
-82
lines changed

aeon/transformations/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
from abc import abstractmethod
77

8+
import numpy as np
9+
import pandas as pd
10+
811
from aeon.base import BaseAeonEstimator
912

1013

@@ -90,3 +93,22 @@ def fit_transform(self, X, y=None):
9093
Additional data, e.g., labels for transformation.
9194
"""
9295
...
96+
97+
def _check_y(self, y, n_cases=None):
98+
# Check that y is a valid input for a supervised transform
99+
if not isinstance(y, (pd.Series, np.ndarray)):
100+
raise TypeError(
101+
f"y must be a np.array or a pd.Series, but found type: {type(y)}"
102+
)
103+
104+
if isinstance(y, np.ndarray) and y.ndim > 1:
105+
raise TypeError(f"y must be 1-dimensional, found {y.ndim} dimensions")
106+
107+
if n_cases is not None:
108+
# Check matching number of labels
109+
n_labels = y.shape[0]
110+
if n_cases != n_labels:
111+
raise ValueError(
112+
f"Mismatch in number of cases. Number in X = {n_cases} nos in y = "
113+
f"{n_labels}"
114+
)

aeon/transformations/collection/base.py

Lines changed: 36 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,23 @@ class name: BaseCollectionTransformer
1919
fitted state inspection - check_is_fitted()
2020
"""
2121

22-
__maintainer__ = []
22+
__maintainer__ = ["MatthewMiddlehurst"]
2323
__all__ = [
2424
"BaseCollectionTransformer",
2525
]
2626

2727
from abc import abstractmethod
2828
from typing import final
2929

30-
import numpy as np
31-
import pandas as pd
32-
3330
from aeon.base import BaseCollectionEstimator
3431
from aeon.transformations.base import BaseTransformer
32+
from aeon.utils.validation import get_n_cases
3533

3634

3735
class BaseCollectionTransformer(BaseCollectionEstimator, BaseTransformer):
3836
"""Transformer base class for collections."""
3937

40-
# tag values specific to CollectionTransformers
38+
# default tag values for collection transformers
4139
_tags = {
4240
"input_data_type": "Collection",
4341
"output_data_type": "Collection",
@@ -84,22 +82,25 @@ def fit(self, X, y=None):
8482
-------
8583
self : a fitted instance of the estimator
8684
"""
87-
if self.get_tag("requires_y"):
88-
if y is None:
89-
raise ValueError("Tag requires_y is true, but fit called with y=None")
90-
# skip the rest if fit_is_empty is True
9185
if self.get_tag("fit_is_empty"):
9286
self.is_fitted = True
9387
return self
88+
89+
if self.get_tag("requires_y"):
90+
if y is None:
91+
raise ValueError("Tag requires_y is true, but fit called with y=None")
92+
93+
# reset estimator at the start of fit
9494
self.reset()
9595

9696
# input checks and datatype conversion
97-
X_inner = self._preprocess_collection(X)
98-
y_inner = y
99-
self._fit(X=X_inner, y=y_inner)
97+
X = self._preprocess_collection(X, store_metadata=True)
98+
if y is not None:
99+
self._check_y(y, n_cases=self.metadata_["n_cases"])
100100

101-
self.is_fitted = True
101+
self._fit(X=X, y=y)
102102

103+
self.is_fitted = True
103104
return self
104105

105106
@final
@@ -139,18 +140,19 @@ def transform(self, X, y=None):
139140
-------
140141
transformed version of X
141142
"""
142-
# check whether is fitted
143-
self._check_is_fitted()
143+
fit_empty = self.get_tag("fit_is_empty")
144+
if not fit_empty:
145+
self._check_is_fitted()
144146

145-
# input check and conversion for X/y
146-
X_inner = self._preprocess_collection(X, store_metadata=False)
147-
y_inner = y
147+
# input checks and datatype conversion
148+
X = self._preprocess_collection(X, store_metadata=False)
149+
if y is not None:
150+
self._check_y(y, n_cases=get_n_cases(X))
148151

149-
if not self.get_tag("fit_is_empty"):
152+
if not fit_empty:
150153
self._check_shape(X)
151154

152-
Xt = self._transform(X=X_inner, y=y_inner)
153-
155+
Xt = self._transform(X, y)
154156
return Xt
155157

156158
@final
@@ -192,14 +194,21 @@ def fit_transform(self, X, y=None):
192194
-------
193195
transformed version of X
194196
"""
195-
# input checks and datatype conversion
197+
if self.get_tag("requires_y"):
198+
if y is None:
199+
raise ValueError("Tag requires_y is true, but fit called with y=None")
200+
201+
# reset estimator at the start of fit
196202
self.reset()
197-
X_inner = self._preprocess_collection(X)
198-
y_inner = y
199-
Xt = self._fit_transform(X=X_inner, y=y_inner)
200203

201-
self.is_fitted = True
204+
# input checks and datatype conversion
205+
X = self._preprocess_collection(X, store_metadata=True)
206+
if y is not None:
207+
self._check_y(y, n_cases=self.metadata_["n_cases"])
202208

209+
Xt = self._fit_transform(X=X, y=y)
210+
211+
self.is_fitted = True
203212
return Xt
204213

205214
@final
@@ -297,6 +306,7 @@ def _transform(self, X, y=None):
297306
-------
298307
transformed version of X
299308
"""
309+
...
300310

301311
def _fit_transform(self, X, y=None):
302312
"""Fit to data, then transform it.
@@ -341,41 +351,3 @@ def _inverse_transform(self, X, y=None):
341351
raise NotImplementedError(
342352
f"{self.__class__.__name__} does not support inverse_transform"
343353
)
344-
345-
def _update(self, X, y=None):
346-
"""Update transformer with X and y.
347-
348-
private _update containing the core logic, called from update
349-
350-
Parameters
351-
----------
352-
X : Input data
353-
Data to fit transform to, of valid collection type.
354-
y : Target variable, default=None
355-
Additional data, e.g., labels for transformation
356-
357-
Returns
358-
-------
359-
self: a fitted instance of the estimator.
360-
"""
361-
# standard behaviour: no update takes place, new data is ignored
362-
return self
363-
364-
365-
def _check_y(self, y, n_cases):
366-
if y is None:
367-
return None
368-
# Check y valid input for collection transformations
369-
if not isinstance(y, (pd.Series, np.ndarray)):
370-
raise TypeError(
371-
f"y must be a np.array or a pd.Series, but found type: {type(y)}"
372-
)
373-
if isinstance(y, np.ndarray) and y.ndim > 1:
374-
raise TypeError(f"y must be 1-dimensional, found {y.ndim} dimensions")
375-
# Check matching number of labels
376-
n_labels = y.shape[0]
377-
if n_cases != n_labels:
378-
raise ValueError(
379-
f"Mismatch in number of cases. Number in X = {n_cases} nos in y = "
380-
f"{n_labels}"
381-
)

aeon/transformations/series/base.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,14 @@ class name: BaseSeriesTransformer
1111
from abc import abstractmethod
1212
from typing import final
1313

14-
import numpy as np
15-
import pandas as pd
16-
1714
from aeon.base import BaseSeriesEstimator
1815
from aeon.transformations.base import BaseTransformer
1916

2017

2118
class BaseSeriesTransformer(BaseSeriesEstimator, BaseTransformer):
2219
"""Transformer base class for series."""
2320

24-
# tag values specific to SeriesTransformers
21+
# default tag values for series transformers
2522
_tags = {
2623
"input_data_type": "Series",
2724
"output_data_type": "Series",
@@ -58,19 +55,24 @@ def fit(self, X, y=None, axis=1):
5855
-------
5956
self : a fitted instance of the estimator
6057
"""
61-
# skip the rest if fit_is_empty is True
6258
if self.get_tag("fit_is_empty"):
6359
self.is_fitted = True
6460
return self
61+
6562
if self.get_tag("requires_y"):
6663
if y is None:
6764
raise ValueError("Tag requires_y is true, but fit called with y=None")
65+
6866
# reset estimator at the start of fit
6967
self.reset()
68+
69+
# input checks and datatype conversion
7070
X = self._preprocess_series(X, axis=axis, store_metadata=True)
7171
if y is not None:
7272
self._check_y(y)
73+
7374
self._fit(X=X, y=y)
75+
7476
self.is_fitted = True
7577
return self
7678

@@ -101,9 +103,18 @@ def transform(self, X, y=None, axis=1):
101103
transformed version of X with the same axis as passed by the user, if axis
102104
not None.
103105
"""
104-
# check whether is fitted
105-
self._check_is_fitted()
106+
fit_empty = self.get_tag("fit_is_empty")
107+
if not fit_empty:
108+
self._check_is_fitted()
109+
106110
X = self._preprocess_series(X, axis=axis, store_metadata=False)
111+
if y is not None:
112+
self._check_y(y)
113+
114+
# #2768
115+
# if not fit_empty:
116+
# self._check_shape(X)
117+
107118
Xt = self._transform(X, y)
108119
return self._postprocess_series(Xt, axis=axis)
109120

@@ -137,10 +148,20 @@ def fit_transform(self, X, y=None, axis=1):
137148
transformed version of X with the same axis as passed by the user, if axis
138149
not None.
139150
"""
140-
# input checks and datatype conversion, to avoid doing in both fit and transform
151+
if self.get_tag("requires_y"):
152+
if y is None:
153+
raise ValueError("Tag requires_y is true, but fit called with y=None")
154+
155+
# reset estimator at the start of fit
141156
self.reset()
157+
158+
# input checks and datatype conversion
142159
X = self._preprocess_series(X, axis=axis, store_metadata=True)
160+
if y is not None:
161+
self._check_y(y)
162+
143163
Xt = self._fit_transform(X=X, y=y)
164+
144165
self.is_fitted = True
145166
return self._postprocess_series(Xt, axis=axis)
146167

@@ -263,7 +284,8 @@ def _fit_transform(self, X, y=None):
263284
"""
264285
# Non-optimized default implementation; override when a better
265286
# method is possible for a given algorithm.
266-
return self._fit(X, y)._transform(X, y)
287+
self._fit(X, y)
288+
return self._transform(X, y)
267289

268290
def _inverse_transform(self, X, y=None):
269291
"""Inverse transform X and return an inverse transformed version.
@@ -325,12 +347,3 @@ def _postprocess_series(self, Xt, axis):
325347
return Xt
326348
else:
327349
return Xt.T
328-
329-
def _check_y(self, y):
330-
# Check y valid input for supervised transform
331-
if not isinstance(y, (pd.Series, np.ndarray)):
332-
raise TypeError(
333-
f"y must be a np.array or a pd.Series, but found type: {type(y)}"
334-
)
335-
if isinstance(y, np.ndarray) and y.ndim > 1:
336-
raise TypeError(f"y must be 1-dimensional, found {y.ndim} dimensions")

0 commit comments

Comments
 (0)