@@ -19,25 +19,23 @@ class name: BaseCollectionTransformer
19
19
fitted state inspection - check_is_fitted()
20
20
"""
21
21
22
- __maintainer__ = []
22
+ __maintainer__ = ["MatthewMiddlehurst" ]
23
23
__all__ = [
24
24
"BaseCollectionTransformer" ,
25
25
]
26
26
27
27
from abc import abstractmethod
28
28
from typing import final
29
29
30
- import numpy as np
31
- import pandas as pd
32
-
33
30
from aeon .base import BaseCollectionEstimator
34
31
from aeon .transformations .base import BaseTransformer
32
+ from aeon .utils .validation import get_n_cases
35
33
36
34
37
35
class BaseCollectionTransformer (BaseCollectionEstimator , BaseTransformer ):
38
36
"""Transformer base class for collections."""
39
37
40
- # tag values specific to CollectionTransformers
38
+ # default tag values for collection transformers
41
39
_tags = {
42
40
"input_data_type" : "Collection" ,
43
41
"output_data_type" : "Collection" ,
@@ -84,22 +82,25 @@ def fit(self, X, y=None):
84
82
-------
85
83
self : a fitted instance of the estimator
86
84
"""
87
- if self .get_tag ("requires_y" ):
88
- if y is None :
89
- raise ValueError ("Tag requires_y is true, but fit called with y=None" )
90
- # skip the rest if fit_is_empty is True
91
85
if self .get_tag ("fit_is_empty" ):
92
86
self .is_fitted = True
93
87
return self
88
+
89
+ if self .get_tag ("requires_y" ):
90
+ if y is None :
91
+ raise ValueError ("Tag requires_y is true, but fit called with y=None" )
92
+
93
+ # reset estimator at the start of fit
94
94
self .reset ()
95
95
96
96
# input checks and datatype conversion
97
- X_inner = self ._preprocess_collection (X )
98
- y_inner = y
99
- self ._fit ( X = X_inner , y = y_inner )
97
+ X = self ._preprocess_collection (X , store_metadata = True )
98
+ if y is not None :
99
+ self ._check_y ( y , n_cases = self . metadata_ [ "n_cases" ] )
100
100
101
- self .is_fitted = True
101
+ self ._fit ( X = X , y = y )
102
102
103
+ self .is_fitted = True
103
104
return self
104
105
105
106
@final
@@ -139,18 +140,19 @@ def transform(self, X, y=None):
139
140
-------
140
141
transformed version of X
141
142
"""
142
- # check whether is fitted
143
- self ._check_is_fitted ()
143
+ fit_empty = self .get_tag ("fit_is_empty" )
144
+ if not fit_empty :
145
+ self ._check_is_fitted ()
144
146
145
- # input check and conversion for X/y
146
- X_inner = self ._preprocess_collection (X , store_metadata = False )
147
- y_inner = y
147
+ # input checks and datatype conversion
148
+ X = self ._preprocess_collection (X , store_metadata = False )
149
+ if y is not None :
150
+ self ._check_y (y , n_cases = get_n_cases (X ))
148
151
149
- if not self . get_tag ( "fit_is_empty" ) :
152
+ if not fit_empty :
150
153
self ._check_shape (X )
151
154
152
- Xt = self ._transform (X = X_inner , y = y_inner )
153
-
155
+ Xt = self ._transform (X , y )
154
156
return Xt
155
157
156
158
@final
@@ -192,14 +194,21 @@ def fit_transform(self, X, y=None):
192
194
-------
193
195
transformed version of X
194
196
"""
195
- # input checks and datatype conversion
197
+ if self .get_tag ("requires_y" ):
198
+ if y is None :
199
+ raise ValueError ("Tag requires_y is true, but fit called with y=None" )
200
+
201
+ # reset estimator at the start of fit
196
202
self .reset ()
197
- X_inner = self ._preprocess_collection (X )
198
- y_inner = y
199
- Xt = self ._fit_transform (X = X_inner , y = y_inner )
200
203
201
- self .is_fitted = True
204
+ # input checks and datatype conversion
205
+ X = self ._preprocess_collection (X , store_metadata = True )
206
+ if y is not None :
207
+ self ._check_y (y , n_cases = self .metadata_ ["n_cases" ])
202
208
209
+ Xt = self ._fit_transform (X = X , y = y )
210
+
211
+ self .is_fitted = True
203
212
return Xt
204
213
205
214
@final
@@ -297,6 +306,7 @@ def _transform(self, X, y=None):
297
306
-------
298
307
transformed version of X
299
308
"""
309
+ ...
300
310
301
311
def _fit_transform (self , X , y = None ):
302
312
"""Fit to data, then transform it.
@@ -341,41 +351,3 @@ def _inverse_transform(self, X, y=None):
341
351
raise NotImplementedError (
342
352
f"{ self .__class__ .__name__ } does not support inverse_transform"
343
353
)
344
-
345
- def _update (self , X , y = None ):
346
- """Update transformer with X and y.
347
-
348
- private _update containing the core logic, called from update
349
-
350
- Parameters
351
- ----------
352
- X : Input data
353
- Data to fit transform to, of valid collection type.
354
- y : Target variable, default=None
355
- Additional data, e.g., labels for transformation
356
-
357
- Returns
358
- -------
359
- self: a fitted instance of the estimator.
360
- """
361
- # standard behaviour: no update takes place, new data is ignored
362
- return self
363
-
364
-
365
- def _check_y (self , y , n_cases ):
366
- if y is None :
367
- return None
368
- # Check y valid input for collection transformations
369
- if not isinstance (y , (pd .Series , np .ndarray )):
370
- raise TypeError (
371
- f"y must be a np.array or a pd.Series, but found type: { type (y )} "
372
- )
373
- if isinstance (y , np .ndarray ) and y .ndim > 1 :
374
- raise TypeError (f"y must be 1-dimensional, found { y .ndim } dimensions" )
375
- # Check matching number of labels
376
- n_labels = y .shape [0 ]
377
- if n_cases != n_labels :
378
- raise ValueError (
379
- f"Mismatch in number of cases. Number in X = { n_cases } nos in y = "
380
- f"{ n_labels } "
381
- )
0 commit comments