Skip to content

Commit 236d039

Browse files
[MNT] Raise version bound for scikit-learn 1.6 (aeon-toolkit#2486)
* update ver and new tags * default tags * toml * Update _shapelets.py Fix linear estimator coefs issue * expected results * Change expected results * update * only linux * remove mixins just to see test * revert --------- Co-authored-by: Antoine Guillaume <antoine.guillaume45@gmail.com>
1 parent 3106722 commit 236d039

File tree

14 files changed

+1924
-230
lines changed

14 files changed

+1924
-230
lines changed

aeon/base/_base.py

+12
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,18 @@ def __sklearn_is_fitted__(self):
415415
"""Check fitted status and return a Boolean value."""
416416
return self.is_fitted
417417

418+
def __sklearn_tags__(self):
419+
"""Return sklearn style tags for the estimator."""
420+
aeon_tags = self.get_tags()
421+
sklearn_tags = super().__sklearn_tags__()
422+
sklearn_tags.non_deterministic = aeon_tags.get("non_deterministic", False)
423+
sklearn_tags.target_tags.one_d_labels = True
424+
sklearn_tags.input_tags.three_d_array = True
425+
sklearn_tags.input_tags.allow_nan = aeon_tags.get(
426+
"capability:missing_values", False
427+
)
428+
return sklearn_tags
429+
418430
def _validate_data(self, **kwargs):
419431
"""Sklearn data validation."""
420432
raise NotImplementedError(

aeon/classification/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class name: BaseClassifier
2626

2727
import numpy as np
2828
import pandas as pd
29+
from sklearn.base import ClassifierMixin
2930
from sklearn.metrics import get_scorer, get_scorer_names
3031
from sklearn.model_selection import cross_val_predict
3132

@@ -35,7 +36,7 @@ class name: BaseClassifier
3536
from aeon.utils.validation.labels import check_classification_y
3637

3738

38-
class BaseClassifier(BaseCollectionEstimator):
39+
class BaseClassifier(ClassifierMixin, BaseCollectionEstimator):
3940
"""
4041
Abstract base class for time series classifiers.
4142
@@ -66,7 +67,6 @@ def __init__(self):
6667
self.classes_ = [] # classes seen in y, unique labels
6768
self.n_classes_ = -1 # number of unique classes in y
6869
self._class_dictionary = {}
69-
self._estimator_type = "classifier"
7070

7171
super().__init__()
7272

aeon/clustering/base.py

+2-23
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
from typing import final
88

99
import numpy as np
10+
from sklearn.base import ClusterMixin
1011

1112
from aeon.base import BaseCollectionEstimator
1213

1314

14-
class BaseClusterer(BaseCollectionEstimator):
15+
class BaseClusterer(ClusterMixin, BaseCollectionEstimator):
1516
"""Abstract base class for time series clusterers.
1617
1718
Parameters
@@ -26,10 +27,6 @@ class BaseClusterer(BaseCollectionEstimator):
2627

2728
@abstractmethod
2829
def __init__(self):
29-
# required for compatibility with some sklearn interfaces e.g.
30-
# CalibratedClassifierCV
31-
self._estimator_type = "clusterer"
32-
3330
super().__init__()
3431

3532
@final
@@ -132,24 +129,6 @@ def fit_predict(self, X, y=None) -> np.ndarray:
132129
to return.
133130
y: ignored, exists for API consistency reasons.
134131
135-
Returns
136-
-------
137-
np.ndarray (1d array of shape (n_cases,))
138-
Index of the cluster each time series in X belongs to.
139-
"""
140-
return self._fit_predict(X, y)
141-
142-
def _fit_predict(self, X, y=None) -> np.ndarray:
143-
"""Fit predict using base methods.
144-
145-
Parameters
146-
----------
147-
X : np.ndarray (2d or 3d array of shape (n_cases, n_timepoints) or shape
148-
(n_cases, n_channels, n_timepoints)).
149-
Time series instances to train clusterer and then have indexes each belong
150-
to return.
151-
y: ignored, exists for API consistency reasons.
152-
153132
Returns
154133
-------
155134
np.ndarray (1d array of shape (n_cases,))

aeon/regression/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class name: BaseRegressor
2525

2626
import numpy as np
2727
import pandas as pd
28+
from sklearn.base import RegressorMixin
2829
from sklearn.metrics import get_scorer, get_scorer_names
2930
from sklearn.model_selection import cross_val_predict
3031
from sklearn.utils.multiclass import type_of_target
@@ -33,7 +34,7 @@ class name: BaseRegressor
3334
from aeon.base._base import _clone_estimator
3435

3536

36-
class BaseRegressor(BaseCollectionEstimator):
37+
class BaseRegressor(RegressorMixin, BaseCollectionEstimator):
3738
"""Abstract base class for time series regressors.
3839
3940
The base regressor specifies the methods and method signatures that all
@@ -54,9 +55,6 @@ class BaseRegressor(BaseCollectionEstimator):
5455

5556
@abstractmethod
5657
def __init__(self):
57-
# required for compatibility with some sklearn interfaces
58-
self._estimator_type = "regressor"
59-
6058
super().__init__()
6159

6260
@final

aeon/regression/feature_based/_catch22.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ class Catch22Regressor(BaseRegressor):
9494
>>> reg.fit(X, y)
9595
Catch22Regressor(...)
9696
>>> reg.predict(X)
97-
array([0.63821896, 1.0906666 , 0.58323551, 1.57550709, 0.48413489,
98-
0.70976176, 1.33206165, 1.09927538, 1.51673405, 0.31683308])
97+
array([0.63821896, 1.0906666 , 0.64351536, 1.57550709, 0.46036267,
98+
0.79297397, 1.32882497, 1.12603087, 1.51673405, 0.31683308])
9999
"""
100100

101101
_tags = {

aeon/regression/sklearn/tests/test_rotation_forest_regressor.py

+15-15
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,21 @@ def test_rotf_output():
2424
rotf.fit(X_train, y_train)
2525

2626
expected = [
27-
0.02694297,
28-
0.02694297,
29-
0.01997832,
30-
0.04276962,
31-
0.09027588,
32-
0.02706564,
33-
0.02553648,
34-
0.04075808,
35-
0.02900289,
36-
0.04248546,
37-
0.02694297,
38-
0.03667328,
39-
0.0235855,
40-
0.03444119,
41-
0.0235855,
27+
0.026,
28+
0.0245,
29+
0.0224,
30+
0.0453,
31+
0.0892,
32+
0.0314,
33+
0.026,
34+
0.0451,
35+
0.0287,
36+
0.04,
37+
0.026,
38+
0.0378,
39+
0.0265,
40+
0.0356,
41+
0.0281,
4242
]
4343

4444
np.testing.assert_array_almost_equal(expected, rotf.predict(X_test[:15]), decimal=4)

aeon/testing/estimator_checking/_yield_classification_checks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
def _yield_classification_checks(estimator_class, estimator_instances, datatypes):
3232
"""Yield all classification checks for an aeon classifier."""
3333
# only class required
34-
if sys.platform != "darwin": # We cannot guarantee same results on ARM macOS
34+
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
3535
# Compare against results for both UnitTest and BasicMotions if available
3636
yield partial(
3737
check_classifier_against_expected_results,

aeon/testing/estimator_checking/_yield_regression_checks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
def _yield_regression_checks(estimator_class, estimator_instances, datatypes):
2727
"""Yield all regression checks for an aeon regressor."""
2828
# only class required
29-
if sys.platform != "darwin": # We cannot guarantee same results on ARM macOS
29+
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
3030
# Compare against results for both Covid3Month and CardanoSentiment if available
3131
yield partial(
3232
check_regressor_against_expected_results,

aeon/testing/estimator_checking/_yield_transformation_checks.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
def _yield_transformation_checks(estimator_class, estimator_instances, datatypes):
2727
"""Yield all transformation checks for an aeon transformer."""
2828
# only class required
29-
if sys.platform != "darwin":
29+
if sys.platform == "linux": # We cannot guarantee same results on ARM macOS
30+
# Compare against results for both UnitTest and BasicMotions if available
3031
yield partial(
3132
check_transformer_against_expected_results,
3233
estimator_class=estimator_class,

aeon/testing/expected_results/expected_classifier_outputs.py

+30-30
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,16 @@
6767
)
6868
unit_test_proba["TemporalDictionaryEnsemble"] = np.array(
6969
[
70-
[0.2778, 0.7222],
71-
[0.7222, 0.2778],
70+
[0.3307, 0.6693],
71+
[0.6693, 0.3307],
7272
[0.0, 1.0],
73-
[0.6251, 0.3749],
74-
[0.3749, 0.6251],
73+
[0.5538, 0.4462],
74+
[0.6693, 0.3307],
7575
[1.0, 0.0],
76-
[0.3749, 0.6251],
76+
[0.4462, 0.5538],
7777
[0.0, 1.0],
78-
[0.4653, 0.5347],
79-
[0.3749, 0.6251],
78+
[0.5538, 0.4462],
79+
[0.4462, 0.5538],
8080
]
8181
)
8282
unit_test_proba["WEASEL"] = np.array(
@@ -263,16 +263,16 @@
263263
)
264264
unit_test_proba["HIVECOTEV2"] = np.array(
265265
[
266-
[0.0613, 0.9387],
267-
[0.5531, 0.4479],
268-
[0.0431, 0.9569],
266+
[0.2239, 0.7761],
267+
[0.6732, 0.3268],
268+
[0.1211, 0.8789],
269269
[1.0, 0.0],
270-
[0.9751, 0.0249],
270+
[0.9818, 0.0182],
271271
[1.0, 0.0],
272-
[0.7398, 0.2602],
273-
[0.0365, 0.9635],
274-
[0.7829, 0.2171],
275-
[0.9236, 0.0764],
272+
[0.7201, 0.2799],
273+
[0.2058, 0.7942],
274+
[0.8412, 0.1588],
275+
[0.9441, 0.0559],
276276
]
277277
)
278278
unit_test_proba["CanonicalIntervalForestClassifier"] = np.array(
@@ -293,12 +293,12 @@
293293
[
294294
[0.1, 0.9],
295295
[0.8, 0.2],
296-
[0.0, 1.0],
296+
[0.1, 0.9],
297297
[1.0, 0.0],
298298
[0.7, 0.3],
299299
[0.9, 0.1],
300300
[0.8, 0.2],
301-
[0.4, 0.6],
301+
[0.5, 0.5],
302302
[0.9, 0.1],
303303
[1.0, 0.0],
304304
]
@@ -379,11 +379,11 @@
379379
[0.3505, 0.6495],
380380
[0.1753, 0.8247],
381381
[0.8247, 0.1753],
382-
[0.3505, 0.6495],
382+
[0.6495, 0.3505],
383383
[0.701, 0.299],
384384
[0.6495, 0.3505],
385385
[0.1753, 0.8247],
386-
[0.5258, 0.4742],
386+
[0.8247, 0.1753],
387387
[1.0, 0.0],
388388
]
389389
)
@@ -656,12 +656,12 @@
656656
)
657657
basic_motions_proba["FreshPRINCEClassifier"] = np.array(
658658
[
659-
[0.0, 0.0, 0.1, 0.9],
659+
[0.0, 0.0, 0.2, 0.8],
660660
[0.9, 0.1, 0.0, 0.0],
661661
[0.0, 0.0, 0.8, 0.2],
662662
[0.1, 0.9, 0.0, 0.0],
663-
[0.1, 0.0, 0.0, 0.9],
664-
[0.0, 0.0, 0.1, 0.9],
663+
[0.1, 0.0, 0.1, 0.8],
664+
[0.0, 0.0, 0.2, 0.8],
665665
[0.7, 0.3, 0.0, 0.0],
666666
[0.0, 0.0, 1.0, 0.0],
667667
[0.0, 1.0, 0.0, 0.0],
@@ -782,15 +782,15 @@
782782
)
783783
basic_motions_proba["DrCIFClassifier"] = np.array(
784784
[
785+
[0.1, 0.0, 0.2, 0.7],
786+
[0.5, 0.4, 0.0, 0.1],
787+
[0.0, 0.0, 0.8, 0.2],
788+
[0.1, 0.9, 0.0, 0.0],
789+
[0.1, 0.0, 0.3, 0.6],
785790
[0.0, 0.0, 0.2, 0.8],
786-
[0.4, 0.5, 0.1, 0.0],
787-
[0.0, 0.0, 0.7, 0.3],
788-
[0.2, 0.8, 0.0, 0.0],
789-
[0.0, 0.0, 0.3, 0.7],
790-
[0.0, 0.0, 0.3, 0.7],
791-
[0.7, 0.2, 0.1, 0.0],
792-
[0.0, 0.0, 0.7, 0.3],
793-
[0.1, 0.7, 0.1, 0.1],
791+
[0.5, 0.3, 0.0, 0.2],
792+
[0.0, 0.0, 0.8, 0.2],
793+
[0.2, 0.7, 0.0, 0.1],
794794
[0.0, 0.9, 0.0, 0.1],
795795
]
796796
)

0 commit comments

Comments
 (0)