|
24 | 24 | from aeon.regression import BaseRegressor
|
25 | 25 | from aeon.regression.deep_learning.base import BaseDeepRegressor
|
26 | 26 | from aeon.segmentation import BaseSegmenter
|
27 |
| -from aeon.similarity_search import BaseSimilaritySearch |
28 | 27 | from aeon.testing.estimator_checking._yield_anomaly_detection_checks import (
|
29 | 28 | _yield_anomaly_detection_checks,
|
30 | 29 | )
|
|
34 | 33 | from aeon.testing.estimator_checking._yield_clustering_checks import (
|
35 | 34 | _yield_clustering_checks,
|
36 | 35 | )
|
37 |
| -from aeon.testing.estimator_checking._yield_collection_transformation_checks import ( |
38 |
| - _yield_collection_transformation_checks, |
39 |
| -) |
40 | 36 | from aeon.testing.estimator_checking._yield_early_classification_checks import (
|
41 | 37 | _yield_early_classification_checks,
|
42 | 38 | )
|
|
49 | 45 | from aeon.testing.estimator_checking._yield_segmentation_checks import (
|
50 | 46 | _yield_segmentation_checks,
|
51 | 47 | )
|
52 |
| -from aeon.testing.estimator_checking._yield_series_transformation_checks import ( |
53 |
| - _yield_series_transformation_checks, |
54 |
| -) |
55 |
| -from aeon.testing.estimator_checking._yield_similarity_search_checks import ( |
56 |
| - _yield_similarity_search_checks, |
57 |
| -) |
58 | 48 | from aeon.testing.estimator_checking._yield_soft_dependency_checks import (
|
59 | 49 | _yield_soft_dependency_checks,
|
60 | 50 | )
|
|
69 | 59 | from aeon.testing.utils.deep_equals import deep_equals
|
70 | 60 | from aeon.testing.utils.estimator_checks import _get_tag, _run_estimator_method
|
71 | 61 | from aeon.transformations.base import BaseTransformer
|
72 |
| -from aeon.transformations.collection import BaseCollectionTransformer |
73 |
| -from aeon.transformations.series import BaseSeriesTransformer |
74 | 62 | from aeon.utils.base import VALID_ESTIMATOR_BASES
|
75 | 63 | from aeon.utils.tags import check_valid_tags
|
76 | 64 | from aeon.utils.validation._dependencies import _check_estimator_deps
|
@@ -153,26 +141,11 @@ def _yield_all_aeon_checks(
|
153 | 141 | estimator_class, estimator_instances, datatypes
|
154 | 142 | )
|
155 | 143 |
|
156 |
| - if issubclass(estimator_class, BaseSimilaritySearch): |
157 |
| - yield from _yield_similarity_search_checks( |
158 |
| - estimator_class, estimator_instances, datatypes |
159 |
| - ) |
160 |
| - |
161 | 144 | if issubclass(estimator_class, BaseTransformer):
|
162 | 145 | yield from _yield_transformation_checks(
|
163 | 146 | estimator_class, estimator_instances, datatypes
|
164 | 147 | )
|
165 | 148 |
|
166 |
| - if issubclass(estimator_class, BaseCollectionTransformer): |
167 |
| - yield from _yield_collection_transformation_checks( |
168 |
| - estimator_class, estimator_instances, datatypes |
169 |
| - ) |
170 |
| - |
171 |
| - if issubclass(estimator_class, BaseSeriesTransformer): |
172 |
| - yield from _yield_series_transformation_checks( |
173 |
| - estimator_class, estimator_instances, datatypes |
174 |
| - ) |
175 |
| - |
176 | 149 |
|
177 | 150 | def _yield_estimator_checks(estimator_class, estimator_instances, datatypes):
|
178 | 151 | """Yield all general checks for an aeon estimator."""
|
@@ -289,6 +262,11 @@ def check_has_common_interface(estimator_class):
|
289 | 262 | "axis" not in estimator_class.__dict__
|
290 | 263 | ), "axis should not be a class parameter"
|
291 | 264 |
|
| 265 | + # Must have at least one set to True |
| 266 | + multi = estimator_class.get_class_tag(tag_name="capability:multivariate") |
| 267 | + uni = estimator_class.get_class_tag(tag_name="capability:univariate") |
| 268 | + assert multi or uni |
| 269 | + |
292 | 270 |
|
293 | 271 | def check_set_params(estimator_class):
|
294 | 272 | """Check that set_params works correctly."""
|
|
0 commit comments