Skip to content

Commit 3db87e5

Browse files
committed
some further cleanup
1 parent 7ff247f commit 3db87e5

File tree

10 files changed

+110
-61
lines changed

10 files changed

+110
-61
lines changed

mlscorecheck/check/binary/_check_1_dataset_unknown_folds_mos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def check_1_dataset_unknown_folds_mos(
4545
folding: dict,
4646
scores: dict,
4747
eps,
48-
fold_score_bounds: dict | None = None,
4948
*,
49+
score_bounds: dict | None = None,
5050
solver_name: str | None = None,
5151
timeout: int | None = None,
5252
verbosity: int = 1,

mlscorecheck/check/binary/_check_n_datasets_mos_unknown_folds_mos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def check_n_datasets_mos_unknown_folds_mos(
4848
evaluations: list,
4949
scores: dict,
5050
eps,
51-
dataset_score_bounds: dict | None = None,
5251
*,
52+
score_bounds: dict | None = None,
5353
solver_name: str | None = None,
5454
timeout: int | None = None,
5555
verbosity: int = 1,

mlscorecheck/check/bundles/retina/_chasedb1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
]
2020

2121

22-
def _filter_chasedb1(data, imageset, annotator):
22+
def _filter_chasedb1(data: dict, imageset, annotator: str) -> list:
2323
"""
2424
Filters the CHASEDB1 dataset
2525
@@ -107,8 +107,8 @@ def check_chasedb1_vessel_aggregated_mos(
107107

108108

109109
def check_chasedb1_vessel_aggregated_som(
110-
imageset, annotator, scores, eps, numerical_tolerance=NUMERICAL_TOLERANCE
111-
):
110+
imageset, annotator: str, scores: dict, eps, numerical_tolerance=NUMERICAL_TOLERANCE
111+
) -> dict:
112112
"""
113113
Tests the consistency of scores calculated on the CHASEDB1 dataset using
114114
the score-of-means aggregation.
@@ -252,7 +252,7 @@ def check_chasedb1_vessel_image(
252252
eps,
253253
*,
254254
numerical_tolerance: float = NUMERICAL_TOLERANCE,
255-
):
255+
) -> dict:
256256
"""
257257
Testing the scores calculated for one image of the CHASEDB1 dataset
258258

mlscorecheck/check/bundles/retina/_hrf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
]
2222

2323

24-
def _filter_hrf(data, imageset, assumption):
24+
def _filter_hrf(data: dict, imageset, assumption: str) -> list:
2525
"""
2626
Filters the HRF dataset
2727
@@ -118,7 +118,7 @@ def check_hrf_vessel_aggregated_mos_assumption(
118118

119119
def check_hrf_vessel_aggregated_som_assumption(
120120
imageset, assumption: str, scores: dict, eps, numerical_tolerance=NUMERICAL_TOLERANCE
121-
):
121+
) -> dict:
122122
"""
123123
Tests the consistency of scores calculated on the HRF dataset using
124124
the score-of-means aggregation and an assumption on the region of evaluation.
@@ -183,7 +183,7 @@ def check_hrf_vessel_image_assumption(
183183
eps,
184184
*,
185185
numerical_tolerance: float = NUMERICAL_TOLERANCE,
186-
):
186+
) -> dict:
187187
"""
188188
Testing the scores calculated for one image of the HRF dataset using an
189189
assumption on the region of evaluation.
@@ -242,6 +242,7 @@ def check_hrf_vessel_image_assumption(
242242

243243
def check_hrf_vessel_aggregated(
244244
imageset,
245+
assumption: str,
245246
scores: dict,
246247
eps,
247248
*,
@@ -325,7 +326,7 @@ def check_hrf_vessel_aggregated(
325326

326327
def check_hrf_vessel_image(
327328
image_identifier: str, scores: dict, eps, *, numerical_tolerance: float = NUMERICAL_TOLERANCE
328-
):
329+
) -> dict:
329330
"""
330331
Testing the scores calculated for one image of the HRF dataset with
331332
both assumptions on the region of evaluation ('fov'/'all')

mlscorecheck/check/bundles/retina/_stare.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
]
2020

2121

22-
def _filter_stare(data, imageset, annotator):
22+
def _filter_stare(data: dict, imageset, annotator: str) -> list:
2323
"""
2424
Filters the STARE dataset
2525
@@ -111,8 +111,8 @@ def check_stare_vessel_aggregated_mos(
111111

112112

113113
def check_stare_vessel_aggregated_som(
114-
imageset, annotator, scores, eps, numerical_tolerance=NUMERICAL_TOLERANCE
115-
):
114+
imageset, annotator: str, scores: dict, eps, numerical_tolerance=NUMERICAL_TOLERANCE
115+
) -> dict:
116116
"""
117117
Tests the consistency of scores calculated on the STARE dataset using
118118
the score-of-means aggregation.
@@ -257,7 +257,7 @@ def check_stare_vessel_image(
257257
eps,
258258
*,
259259
numerical_tolerance: float = NUMERICAL_TOLERANCE,
260-
):
260+
) -> dict:
261261
"""
262262
Testing the scores calculated for one image of the STARE dataset
263263

mlscorecheck/check/bundles/skinlesion/_isic2016.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ["check_isic2016"]
1010

1111

12-
def check_isic2016(*, scores: dict, eps: float, numerical_tolerance: float = NUMERICAL_TOLERANCE):
12+
def check_isic2016(*, scores: dict, eps: float, numerical_tolerance: float = NUMERICAL_TOLERANCE) -> dict:
1313
"""
1414
Tests if the scores are consistent with the test set of the ISIC2016
1515
melanoma classification dataset

mlscorecheck/check/bundles/skinlesion/_isic2017.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ["check_isic2017", "_prepare_testset_isic2017"]
1010

1111

12-
def _prepare_testset_isic2017(target, against):
12+
def _prepare_testset_isic2017(target: str | list, against: str | list | None) -> dict:
1313
"""
1414
Preparation of the test set
1515
@@ -25,7 +25,12 @@ def _prepare_testset_isic2017(target, against):
2525
data = get_experiment("skinlesion.isic2017")
2626

2727
target = [target] if isinstance(target, str) else target
28-
against = [against] if isinstance(against, str) else against
28+
29+
if against is None:
30+
all_classes = ['M', 'SK', 'N']
31+
against = [cls for cls in all_classes if cls not in target]
32+
else:
33+
against = [against] if isinstance(against, str) else against
2934

3035
mapping = {"M": "melanoma", "SK": "seborrheic keratosis", "N": "nevus"}
3136

@@ -36,8 +41,13 @@ def _prepare_testset_isic2017(target, against):
3641

3742

3843
def check_isic2017(
39-
*, target, against, scores: dict, eps: float, numerical_tolerance: float = NUMERICAL_TOLERANCE
40-
):
44+
target: str,
45+
scores: dict,
46+
eps,
47+
*,
48+
against: str | None = None,
49+
numerical_tolerance: float = NUMERICAL_TOLERANCE,
50+
) -> dict:
4151
"""
4252
Tests if the scores are consistent with the test set of the ISIC2017
4353
skin lesion classification dataset. The dataset contains three classes,

tests/aggregated/_evaluate_lp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313

1414
def evaluate_timeout(
15-
result: pl.LpProblem, problem: Experiment, scores: dict, eps, score_subset: list
16-
):
15+
result: pl.LpProblem, problem: Experiment, scores: dict, eps, score_subset: list[str]
16+
) -> None:
1717
"""
1818
Evaluate the stopped or succeeded tests
1919

tests/aggregated/test_evaluation.py

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from mlscorecheck.aggregated import (
1111
Evaluation,
12+
Experiment,
1213
compare_scores,
1314
generate_dataset,
1415
generate_evaluation,
@@ -44,7 +45,7 @@
4445
random_seeds = list(range(5))
4546

4647

47-
def test_evaluate_timeout():
48+
def test_evaluate_timeout() -> None:
4849
"""
4950
Testing the evaluate_timeout function
5051
"""
@@ -54,22 +55,28 @@ class Mock: # pylint: disable=too-few-public-methods
5455
Mock lp_problem class
5556
"""
5657

57-
def __init__(self):
58+
def __init__(self) -> None:
5859
"""
5960
Constructor of the mock class
6061
"""
6162
self.status = 0
6263

6364
mock = Mock()
6465

66+
# Create dummy objects for testing - need Experiment, not Evaluation
67+
dummy_evaluation_dict = generate_evaluation(random_state=42)
68+
dummy_experiment = Experiment(evaluations=[dummy_evaluation_dict], aggregation="som")
69+
dummy_scores: dict = {"acc": 0.5}
70+
dummy_subset: list[str] = ["acc"]
71+
6572
with warnings.catch_warnings(record=True) as warn:
66-
evaluate_timeout(mock, None, None, None, None)
73+
evaluate_timeout(mock, dummy_experiment, dummy_scores, 0.1, dummy_subset)
6774
assert len(warn) == 1
6875

6976

7077
@pytest.mark.parametrize("random_seed", random_seeds)
7178
@pytest.mark.parametrize("aggregation", ["mos", "som"])
72-
def test_instantiation(random_seed: int, aggregation: str):
79+
def test_instantiation(random_seed: int, aggregation: str) -> None:
7380
"""
7481
Testing the instantiation of evaluations
7582
@@ -95,7 +102,7 @@ def test_instantiation(random_seed: int, aggregation: str):
95102

96103
@pytest.mark.parametrize("random_seed", random_seeds)
97104
@pytest.mark.parametrize("aggregation", ["mos", "som"])
98-
def test_sample_figures(random_seed: int, aggregation: str):
105+
def test_sample_figures(random_seed: int, aggregation: str) -> None:
99106
"""
100107
Testing the sampling of figures
101108
@@ -119,8 +126,8 @@ def test_sample_figures(random_seed: int, aggregation: str):
119126
@pytest.mark.parametrize("aggregation", ["mos", "som"])
120127
@pytest.mark.parametrize("rounding_decimals", [2, 3, 4])
121128
def test_linear_programming_success(
122-
subset: list, random_seed: int, aggregation: str, rounding_decimals: int
123-
):
129+
subset: list[str], random_seed: int, aggregation: str, rounding_decimals: int
130+
) -> None:
124131
"""
125132
Testing the linear programming functionalities
126133
@@ -163,8 +170,8 @@ def test_linear_programming_success(
163170
@pytest.mark.parametrize("aggregation", ["mos", "som"])
164171
@pytest.mark.parametrize("rounding_decimals", [2, 3, 4])
165172
def test_linear_programming_evaluation_generation_success(
166-
subset: list, random_seed: int, aggregation: str, rounding_decimals: int
167-
):
173+
subset: list[str], random_seed: int, aggregation: str, rounding_decimals: int
174+
) -> None:
168175
"""
169176
Testing the linear programming functionalities by generating the evaluation
170177
@@ -175,9 +182,15 @@ def test_linear_programming_evaluation_generation_success(
175182
rounding_decimals (int): the number of decimals to round to
176183
"""
177184

178-
evaluation = generate_evaluation(random_state=random_seed, aggregation=aggregation)
185+
evaluation_dict = generate_evaluation(random_state=random_seed, aggregation=aggregation)
186+
assert isinstance(evaluation_dict, dict), "generate_evaluation should return dict when return_scores=False"
179187

180-
evaluation = Evaluation(**evaluation)
188+
evaluation = Evaluation(
189+
dataset=evaluation_dict["dataset"],
190+
folding=evaluation_dict["folding"],
191+
aggregation=evaluation_dict["aggregation"],
192+
fold_score_bounds=evaluation_dict.get("fold_score_bounds"),
193+
)
181194

182195
evaluation.sample_figures(random_state=random_seed)
183196

@@ -203,7 +216,7 @@ def test_linear_programming_evaluation_generation_success(
203216
@pytest.mark.parametrize("aggregation", ["mos", "som"])
204217
def test_linear_programming_evaluation_generation_failure(
205218
random_seed: int, aggregation: str
206-
):
219+
) -> None:
207220
"""
208221
Testing the linear programming functionalities by generating the evaluation
209222
@@ -212,9 +225,15 @@ def test_linear_programming_evaluation_generation_failure(
212225
aggregation (str): the aggregation to use ('mos'/'som')
213226
"""
214227

215-
evaluation = generate_evaluation(random_state=random_seed, aggregation=aggregation)
228+
evaluation_dict = generate_evaluation(random_state=random_seed, aggregation=aggregation)
229+
assert isinstance(evaluation_dict, dict), "generate_evaluation should return dict when return_scores=False"
216230

217-
evaluation = Evaluation(**evaluation)
231+
evaluation = Evaluation(
232+
dataset=evaluation_dict["dataset"],
233+
folding=evaluation_dict["folding"],
234+
aggregation=evaluation_dict["aggregation"],
235+
fold_score_bounds=evaluation_dict.get("fold_score_bounds"),
236+
)
218237

219238
evaluation.sample_figures(random_state=random_seed)
220239

@@ -229,7 +248,7 @@ def test_linear_programming_evaluation_generation_failure(
229248

230249
@pytest.mark.parametrize("random_seed", random_seeds)
231250
@pytest.mark.parametrize("aggregation", ["mos", "som"])
232-
def test_get_fold_score_bounds(random_seed: int, aggregation: str):
251+
def test_get_fold_score_bounds(random_seed: int, aggregation: str) -> None:
233252
"""
234253
Testing the extraction of fold score bounds
235254
@@ -238,9 +257,15 @@ def test_get_fold_score_bounds(random_seed: int, aggregation: str):
238257
aggregation (str): the aggregation to use ('mos'/'som')
239258
"""
240259

241-
evaluation = generate_evaluation(random_state=random_seed, aggregation=aggregation)
260+
evaluation_dict = generate_evaluation(random_state=random_seed, aggregation=aggregation)
261+
assert isinstance(evaluation_dict, dict), "generate_evaluation should return dict when return_scores=False"
242262

243-
evaluation = Evaluation(**evaluation)
263+
evaluation = Evaluation(
264+
dataset=evaluation_dict["dataset"],
265+
folding=evaluation_dict["folding"],
266+
aggregation=evaluation_dict["aggregation"],
267+
fold_score_bounds=evaluation_dict.get("fold_score_bounds"),
268+
)
244269
evaluation.sample_figures().calculate_scores()
245270

246271
score_bounds = get_fold_score_bounds(evaluation, feasible=True)
@@ -255,8 +280,8 @@ def test_get_fold_score_bounds(random_seed: int, aggregation: str):
255280
@pytest.mark.parametrize("aggregation", ["mos"])
256281
@pytest.mark.parametrize("rounding_decimals", [3, 4])
257282
def test_linear_programming_success_bounds(
258-
subset: list, random_seed: int, aggregation: str, rounding_decimals: int
259-
):
283+
subset: list[str], random_seed: int, aggregation: str, rounding_decimals: int
284+
) -> None:
260285
"""
261286
Testing the linear programming functionalities by generating the evaluation
262287
with bounds
@@ -287,16 +312,22 @@ def test_linear_programming_success_bounds(
287312

288313
assert lp_program.status in (0, 1)
289314

290-
evaluate_timeout(lp_program, skeleton, scores, 10 ** (-rounding_decimals), subset)
315+
# Direct evaluation instead of evaluate_timeout since we have an Evaluation, not Experiment
316+
if lp_program.status == 1:
317+
populated = skeleton.populate(lp_program)
318+
assert compare_scores(
319+
scores, populated.calculate_scores(), 10 ** (-rounding_decimals), subset
320+
)
321+
assert populated.check_bounds()["bounds_flag"] is True
291322

292323

293324
@pytest.mark.parametrize("subset", two_combs + three_combs + four_combs)
294325
@pytest.mark.parametrize("random_seed", random_seeds)
295326
@pytest.mark.parametrize("aggregation", ["mos"])
296327
@pytest.mark.parametrize("rounding_decimals", [3, 4])
297328
def test_linear_programming_failure_bounds(
298-
subset: list, random_seed: int, aggregation: str, rounding_decimals: int
299-
):
329+
subset: list[str], random_seed: int, aggregation: str, rounding_decimals: int
330+
) -> None:
300331
"""
301332
Testing the linear programming functionalities by generating the evaluation
302333
with bounds
@@ -327,16 +358,23 @@ def test_linear_programming_failure_bounds(
327358

328359
assert lp_program.status in (-1, 0)
329360

330-
evaluate_timeout(lp_program, skeleton, scores, 10 ** (-rounding_decimals), subset)
361+
# Direct evaluation instead of evaluate_timeout since we have an Evaluation, not Experiment
362+
# For infeasible problems, just check the status
331363

332364

333-
def test_others():
365+
def test_others() -> None:
334366
"""
335367
Testing other functionalities
336368
"""
337369

338-
evaluation = generate_evaluation(aggregation="som",
370+
evaluation_dict = generate_evaluation(aggregation="som",
339371
feasible_fold_score_bounds=True,
340372
random_state=5)
373+
assert isinstance(evaluation_dict, dict), "generate_evaluation should return dict when return_scores=False"
341374
with pytest.raises(ValueError):
342-
Evaluation(**evaluation)
375+
Evaluation(
376+
dataset=evaluation_dict["dataset"],
377+
folding=evaluation_dict["folding"],
378+
aggregation=evaluation_dict["aggregation"],
379+
fold_score_bounds=evaluation_dict.get("fold_score_bounds"),
380+
)

0 commit comments

Comments (0)