Skip to content

Commit d8a863c

Browse files
committed
Fix unittest?, increase coverage (hopefully)
1 parent b01f1cb commit d8a863c

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

autosklearn/automl.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,11 +2010,17 @@ def sprint_statistics(self) -> str:
20102010
)
20112011
)[0]
20122012
if len(idx_success) > 0:
2013+
key = (
2014+
"mean_test_score"
2015+
if len(self._metrics) == 1
2016+
else f"mean_test_" f"{self._metrics[0].name}"
2017+
)
2018+
20132019
if not self._metrics[0]._optimum:
2014-
idx_best_run = np.argmin(cv_results["mean_test_score"][idx_success])
2020+
idx_best_run = np.argmin(cv_results[key][idx_success])
20152021
else:
2016-
idx_best_run = np.argmax(cv_results["mean_test_score"][idx_success])
2017-
best_score = cv_results["mean_test_score"][idx_success][idx_best_run]
2022+
idx_best_run = np.argmax(cv_results[key][idx_success])
2023+
best_score = cv_results[key][idx_success][idx_best_run]
20182024
sio.write(" Best validation score: %f\n" % best_score)
20192025
num_runs = len(cv_results["status"])
20202026
sio.write(" Number of target algorithm runs: %d\n" % num_runs)

examples/40_advanced/example_metrics.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
8181
scorer = autosklearn.metrics.accuracy
8282
cls = autosklearn.classification.AutoSklearnClassifier(
8383
time_left_for_this_task=60,
84-
per_run_time_limit=30,
8584
seed=1,
8685
metric=scorer,
8786
)
@@ -107,7 +106,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
107106
)
108107
cls = autosklearn.classification.AutoSklearnClassifier(
109108
time_left_for_this_task=60,
110-
per_run_time_limit=30,
111109
seed=1,
112110
metric=accuracy_scorer,
113111
)
@@ -133,7 +131,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
133131
)
134132
cls = autosklearn.classification.AutoSklearnClassifier(
135133
time_left_for_this_task=60,
136-
per_run_time_limit=30,
137134
seed=1,
138135
metric=error_rate,
139136
)
@@ -184,7 +181,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
184181
)
185182
cls = autosklearn.classification.AutoSklearnClassifier(
186183
time_left_for_this_task=60,
187-
per_run_time_limit=30,
188184
seed=1,
189185
metric=error_rate,
190186
)
@@ -217,10 +213,8 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
217213
)
218214
cls = autosklearn.classification.AutoSklearnClassifier(
219215
time_left_for_this_task=60,
220-
per_run_time_limit=30,
221216
seed=1,
222217
metric=accuracy_scorer,
223-
ensemble_size=0,
224218
)
225219
cls.fit(X_train, y_train)
226220

@@ -232,4 +226,4 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
232226
consider_col=1,
233227
val_threshold=18.8,
234228
)
235-
print(f"Error score {score:.3f} using {error_rate.name:s}")
229+
print(f"Error score {score:.3f} using {accuracy_scorer.name:s}")

test/test_automl/test_post_fit.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def test__load_pareto_front(automl: AutoML) -> None:
8686
"""
8787
# Check that the predict function works
8888
X = np.array([[1.0, 1.0, 1.0, 1.0]])
89-
print(automl.predict(X))
89+
9090
assert automl.predict_proba(X).shape == (1, 3)
9191
assert automl.predict(X).shape == (1,)
9292

@@ -98,3 +98,9 @@ def test__load_pareto_front(automl: AutoML) -> None:
9898
assert y_pred.shape == (1, 3)
9999
y_pred = ensemble.predict(X)
100100
assert y_pred in ["setosa", "versicolor", "virginica"]
101+
102+
statistics = automl.sprint_statistics()
103+
assert "Metrics" in statistics
104+
assert ("Best validation score: 0.9" in statistics) or (
105+
"Best validation score: 1.0" in statistics
106+
), statistics

test/test_ensemble_builder/test_ensemble_builder_real.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_run_builds_valid_ensemble(builder: EnsembleBuilder) -> None:
7878

7979
assert mock_fit.call_count == 1
8080
# Check that the ids of runs in the ensemble were all candidates
81-
candidates = mock_fit.call_args.kwargs["candidates"]
81+
candidates = mock_fit.call_args[1]["candidates"]
8282
candidate_ids = {run.id for run in candidates}
8383
assert ensemble_ids <= candidate_ids
8484

0 commit comments

Comments
 (0)