Skip to content

Commit 05097f5

Browse files
[MNT] Add/rework transformation tests and remove from exclude list (#2360)
* excluded tests * trying to fix things * tidy up transform testing * fixes * fixes * fixes * still trying to make this work * ignore index for pandas * allclose * Empty commit for CI * correct * rist * rist * fix --------- Co-authored-by: MatthewMiddlehurst <MatthewMiddlehurst@users.noreply.github.com>
1 parent 224cdc1 commit 05097f5

11 files changed

+239
-239
lines changed

aeon/testing/estimator_checking/_estimator_checking.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from aeon.testing.estimator_checking._yield_estimator_checks import (
1919
_yield_all_aeon_checks,
2020
)
21-
from aeon.testing.testing_config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS
21+
from aeon.testing.testing_config import (
22+
EXCLUDE_ESTIMATORS,
23+
EXCLUDED_TESTS,
24+
EXCLUDED_TESTS_NO_NUMBA,
25+
NUMBA_DISABLED,
26+
)
2227
from aeon.utils.validation._dependencies import (
2328
_check_estimator_deps,
2429
_check_soft_dependencies,
@@ -313,6 +318,8 @@ def _should_be_skipped(estimator, check, has_dependencies):
313318
return True, "In aeon estimator exclude list", check_name
314319
elif check_name in EXCLUDED_TESTS.get(est_name, []):
315320
return True, "In aeon test exclude list for estimator", check_name
321+
elif NUMBA_DISABLED and check_name in EXCLUDED_TESTS_NO_NUMBA.get(est_name, []):
322+
return True, "In aeon no numba test exclude list for estimator", check_name
316323

317324
return False, "", check_name
318325

aeon/testing/estimator_checking/_yield_classification_checks.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,14 @@ def _yield_classification_checks(estimator_class, estimator_instances, datatypes
4141
results_dict=unit_test_proba,
4242
resample_seed=0,
4343
)
44-
# the test currently fails when numba is disabled. See issue #622
45-
if (
46-
estimator_class.__name__ != "HIVECOTEV2"
47-
or os.environ.get("NUMBA_DISABLE_JIT") != "1"
48-
):
49-
yield partial(
50-
check_classifier_against_expected_results,
51-
estimator_class=estimator_class,
52-
data_name="BasicMotions",
53-
data_loader=load_basic_motions,
54-
results_dict=basic_motions_proba,
55-
resample_seed=4,
56-
)
44+
yield partial(
45+
check_classifier_against_expected_results,
46+
estimator_class=estimator_class,
47+
data_name="BasicMotions",
48+
data_loader=load_basic_motions,
49+
results_dict=basic_motions_proba,
50+
resample_seed=4,
51+
)
5752
yield partial(check_classifier_overrides_and_tags, estimator_class=estimator_class)
5853

5954
# data type irrelevant

aeon/testing/estimator_checking/_yield_collection_transformation_checks.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

aeon/testing/estimator_checking/_yield_estimator_checks.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from aeon.regression import BaseRegressor
2525
from aeon.regression.deep_learning.base import BaseDeepRegressor
2626
from aeon.segmentation import BaseSegmenter
27-
from aeon.similarity_search import BaseSimilaritySearch
2827
from aeon.testing.estimator_checking._yield_anomaly_detection_checks import (
2928
_yield_anomaly_detection_checks,
3029
)
@@ -34,9 +33,6 @@
3433
from aeon.testing.estimator_checking._yield_clustering_checks import (
3534
_yield_clustering_checks,
3635
)
37-
from aeon.testing.estimator_checking._yield_collection_transformation_checks import (
38-
_yield_collection_transformation_checks,
39-
)
4036
from aeon.testing.estimator_checking._yield_early_classification_checks import (
4137
_yield_early_classification_checks,
4238
)
@@ -49,12 +45,6 @@
4945
from aeon.testing.estimator_checking._yield_segmentation_checks import (
5046
_yield_segmentation_checks,
5147
)
52-
from aeon.testing.estimator_checking._yield_series_transformation_checks import (
53-
_yield_series_transformation_checks,
54-
)
55-
from aeon.testing.estimator_checking._yield_similarity_search_checks import (
56-
_yield_similarity_search_checks,
57-
)
5848
from aeon.testing.estimator_checking._yield_soft_dependency_checks import (
5949
_yield_soft_dependency_checks,
6050
)
@@ -69,8 +59,6 @@
6959
from aeon.testing.utils.deep_equals import deep_equals
7060
from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method
7161
from aeon.transformations.base import BaseTransformer
72-
from aeon.transformations.collection import BaseCollectionTransformer
73-
from aeon.transformations.series import BaseSeriesTransformer
7462
from aeon.utils.base import VALID_ESTIMATOR_BASES
7563
from aeon.utils.tags import check_valid_tags
7664
from aeon.utils.validation._dependencies import _check_estimator_deps
@@ -153,26 +141,11 @@ def _yield_all_aeon_checks(
153141
estimator_class, estimator_instances, datatypes
154142
)
155143

156-
if issubclass(estimator_class, BaseSimilaritySearch):
157-
yield from _yield_similarity_search_checks(
158-
estimator_class, estimator_instances, datatypes
159-
)
160-
161144
if issubclass(estimator_class, BaseTransformer):
162145
yield from _yield_transformation_checks(
163146
estimator_class, estimator_instances, datatypes
164147
)
165148

166-
if issubclass(estimator_class, BaseCollectionTransformer):
167-
yield from _yield_collection_transformation_checks(
168-
estimator_class, estimator_instances, datatypes
169-
)
170-
171-
if issubclass(estimator_class, BaseSeriesTransformer):
172-
yield from _yield_series_transformation_checks(
173-
estimator_class, estimator_instances, datatypes
174-
)
175-
176149

177150
def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
178151
"""Yield all general checks for an aeon estimator."""
@@ -289,6 +262,11 @@ def check_has_common_interface(estimator_class):
289262
"axis" not in estimator_class.__dict__
290263
), "axis should not be a class parameter"
291264

265+
# Must have at least one set to True
266+
multi = estimator_class.get_class_tag(tag_name="capability:multivariate")
267+
uni = estimator_class.get_class_tag(tag_name="capability:univariate")
268+
assert multi or uni
269+
292270

293271
def check_set_params(estimator_class):
294272
"""Check that set_params works correctly."""

aeon/testing/estimator_checking/_yield_series_transformation_checks.py

Lines changed: 0 additions & 9 deletions
This file was deleted.

aeon/testing/estimator_checking/_yield_similarity_search_checks.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)