Skip to content

[MNT] Testing fixes #2531

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions aeon/testing/estimator_checking/_yield_estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,9 @@ def check_persistence_via_pickle(estimator, datatype):
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} after serialisation parameters gives "
f"different results. "
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
f"Running {type(estimator)} {method} with test parameters after "
f"serialisation gives different results. "
f"Check equivalence message: {msg}"
)
i += 1

Expand All @@ -657,9 +656,8 @@ def check_fit_deterministic(estimator, datatype):
same, msg = deep_equals(output, results[i], return_msg=True)
if not same:
raise ValueError(
f"Running {method} with test parameters after two calls to fit "
f"gives different results."
f"{type(estimator)} returns data as {type(output)}: test "
f"equivalence message: {msg}"
f"Running {type(estimator)} {method} with test parameters after "
f"two calls to fit gives different results."
f"Check equivalence message: {msg}"
)
i += 1
7 changes: 5 additions & 2 deletions aeon/testing/testing_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
NUMBA_DISABLED = os.environ.get("NUMBA_DISABLE_JIT") == "1"

# exclude estimators here for short term fixes
# Hydra excluded because it returns a pytorch Tensor
EXCLUDE_ESTIMATORS = ["REDCOMETS", "HydraTransformer"]
EXCLUDE_ESTIMATORS = [
"REDCOMETS",
"HydraTransformer", # returns a pytorch Tensor
]

# Exclude specific tests for estimators here
EXCLUDED_TESTS = {
Expand All @@ -50,6 +52,7 @@
"RSASTClassifier": ["check_fit_deterministic"],
"SAST": ["check_fit_deterministic"],
"RSAST": ["check_fit_deterministic"],
"MatrixProfile": ["check_persistence_via_pickle"],
# missed in legacy testing, changes state in predict/transform
"FLUSSSegmenter": ["check_non_state_changing_method"],
"InformationGainSegmenter": ["check_non_state_changing_method"],
Expand Down
14 changes: 10 additions & 4 deletions aeon/testing/utils/deep_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _deep_equals(x, y, depth, ignore_index):
elif isinstance(x, pd.DataFrame):
return _dataframe_equals(x, y, depth, ignore_index)
elif isinstance(x, np.ndarray):
return _numpy_equals(x, y, depth)
return _numpy_equals(x, y, depth, ignore_index)
elif isinstance(x, (list, tuple)):
return _list_equals(x, y, depth, ignore_index)
elif isinstance(x, dict):
Expand Down Expand Up @@ -128,15 +128,21 @@ def _dataframe_equals(x, y, depth, ignore_index):
return eq, msg


def _numpy_equals(x, y, depth):
def _numpy_equals(x, y, depth, ignore_index):
if x.dtype != y.dtype:
return False, f"x.dtype ({x.dtype}) != y.dtype ({y.dtype})"

if x.dtype == "object":
eq, msg = _deep_equals(x.tolist(), y.tolist(), depth, ignore_index=True)
for i in range(len(x)):
eq, msg = _deep_equals(x[i], y[i], depth + 1, ignore_index)

if not eq:
return False, msg + f", idx={i}"
else:
eq = np.allclose(x, y, equal_nan=True)
msg = "" if eq else f"x ({x}) != y ({y}), depth={depth}"
return eq, msg
return eq, msg
return True, ""


def _csrmatrix_equals(x, y, depth):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ addopts = '''
--dist worksteal
--reruns 2
--only-rerun "crashed while running"
--only-rerun "zipfile.BadZipFile"
'''
filterwarnings = '''
ignore::UserWarning
Expand Down