Skip to content

Commit dc25358

Browse files
committed
refactor: track model_ids in cv_results
1 parent a978478 commit dc25358

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

autosklearn/automl.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,15 +1921,17 @@ def cv_results_(self):
19211921
metric_dict[metric.name] = []
19221922
metric_mask[metric.name] = []
19231923

1924+
model_ids = []
19241925
mean_fit_time = []
19251926
params = []
19261927
status = []
19271928
budgets = []
19281929

1929-
for run_key in self.runhistory_.data:
1930-
run_value = self.runhistory_.data[run_key]
1930+
for run_key, run_value in self.runhistory_.data.items():
19311931
config_id = run_key.config_id
19321932
config = self.runhistory_.ids_config[config_id]
1933+
if run_value.additional_info and "num_run" in run_value.additional_info:
1934+
model_ids.append(run_value.additional_info["num_run"])
19331935

19341936
s = run_value.status
19351937
if s == StatusType.SUCCESS:
@@ -1990,6 +1992,8 @@ def cv_results_(self):
19901992
metric_dict[metric.name].append(metric_value)
19911993
metric_mask[metric.name].append(mask_value)
19921994

1995+
results["model_ids"] = model_ids
1996+
19931997
if len(self._metrics) == 1:
19941998
results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name])
19951999
rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"]
@@ -2165,14 +2169,11 @@ def show_models(self) -> dict[int, Any]:
21652169
warnings.warn("No ensemble found. Returning empty dictionary.")
21662170
return ensemble_dict
21672171

2168-
def has_key(rv, key):
2169-
return rv.additional_info and key in rv.additional_info
2170-
21712172
table_dict = {}
2172-
for run_key, run_val in self.runhistory_.data.items():
2173-
if has_key(run_val, "num_run"):
2174-
model_id = run_val.additional_info["num_run"]
2175-
table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost}
2173+
for run_key, run_value in self.runhistory_.data.items():
2174+
if run_value.additional_info and "num_run" in run_value.additional_info:
2175+
model_id = run_value.additional_info["num_run"]
2176+
table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost}
21762177

21772178
# Checking if the dictionary is empty
21782179
if not table_dict:

0 commit comments

Comments
 (0)