Skip to content

refactor: track model_id in cv_results #1627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ def fit(
# space
self._backend.save_start_time(self._seed)

progress_bar.start()
self._stopwatch = StopWatch()

# Make sure that input is valid
Expand Down Expand Up @@ -970,7 +971,7 @@ def fit(
self._logger.exception(e)
raise e
finally:
progress_bar.stop()
progress_bar.join()
self._fit_cleanup()

self.fitted = True
Expand Down Expand Up @@ -1920,15 +1921,17 @@ def cv_results_(self):
metric_dict[metric.name] = []
metric_mask[metric.name] = []

model_ids = []
mean_fit_time = []
params = []
status = []
budgets = []

for run_key in self.runhistory_.data:
run_value = self.runhistory_.data[run_key]
for run_key, run_value in self.runhistory_.data.items():
config_id = run_key.config_id
config = self.runhistory_.ids_config[config_id]
if run_value.additional_info and "num_run" in run_value.additional_info:
model_ids.append(run_value.additional_info["num_run"])

s = run_value.status
if s == StatusType.SUCCESS:
Expand Down Expand Up @@ -1989,6 +1992,8 @@ def cv_results_(self):
metric_dict[metric.name].append(metric_value)
metric_mask[metric.name].append(mask_value)

results["model_ids"] = model_ids

if len(self._metrics) == 1:
results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name])
rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"]
Expand Down Expand Up @@ -2164,14 +2169,11 @@ def show_models(self) -> dict[int, Any]:
warnings.warn("No ensemble found. Returning empty dictionary.")
return ensemble_dict

def has_key(rv, key):
return rv.additional_info and key in rv.additional_info

table_dict = {}
for run_key, run_val in self.runhistory_.data.items():
if has_key(run_val, "num_run"):
model_id = run_val.additional_info["num_run"]
table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost}
for run_key, run_value in self.runhistory_.data.items():
if run_value.additional_info and "num_run" in run_value.additional_info:
model_id = run_value.additional_info["num_run"]
table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost}

# Checking if the dictionary is empty
if not table_dict:
Expand Down
72 changes: 48 additions & 24 deletions autosklearn/util/progress_bar.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import Any

import datetime
Expand All @@ -10,59 +12,81 @@
class ProgressBar(Thread):
"""A Thread that displays a tqdm progress bar in the console.

It is specialized to display information relevant to fitting to the training data
with auto-sklearn.
Treat this class as an ordinary thread: to display a progress bar,
call start() on an instance of this class. To wait for the thread to
terminate, call join(), which maxes out the progress bar and
therefore terminates this thread immediately.

Parameters
----------
total : int
The total amount that should be reached by the progress bar once it finishes
update_interval : float
Specifies how frequently the progress bar is updated (in seconds)
disable : bool
Turns on or off the progress bar. If True, this thread won't be started or
initialized.
kwargs : Any
The total amount that should be reached by the progress bar once it finishes.
update_interval : float, default=1.0
Specifies how frequently the progress bar is updated (in seconds).
disable : bool, default=False
Turns on or off the progress bar. If True, this thread does not get
initialized and won't be started if start() is called.
tqdm_kwargs : Any, optional
Keyword arguments that are passed into tqdm's constructor. Refer to:
`tqdm <https://tqdm.github.io/docs/tqdm/>`_. Note that postfix can not be
specified in the kwargs since it is already passed into tqdm by this class.
`tqdm <https://tqdm.github.io/docs/tqdm/>`_ for a list of parameters that
tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is
already passed into tqdm by this class.

Examples
--------

.. code:: python

progress_bar = ProgressBar(
total=10,
desc="Executing code that runs for 10 seconds",
colour="green",
)
# colour is a tqdm parameter, passed through via tqdm_kwargs
try:
progress_bar.start()
# some code that runs for 10 seconds
except SomeException:
# something went wrong
finally:
progress_bar.join()
# perform some cleanup
"""

def __init__(
self,
total: int,
update_interval: float = 1.0,
disable: bool = False,
**kwargs: Any,
**tqdm_kwargs: Any,
):
self.disable = disable
if not disable:
super().__init__(name="_progressbar_")
self.total = total
self.update_interval = update_interval
self.terminated: bool = False
self.kwargs = kwargs
# start this thread
self.start()
self.tqdm_kwargs = tqdm_kwargs

def run(self) -> None:
"""Display a tqdm progress bar in the console.
def start(self) -> None:
"""Start a new thread that calls the run() method."""
if not self.disable:
super().start()

Additionally, it shows useful information related to the task. This method
overrides the run method of Thread.
"""
def run(self) -> None:
"""Display a tqdm progress bar in the console."""
if not self.disable:
for _ in trange(
self.total,
postfix=f"The total time budget for this task is "
f"{datetime.timedelta(seconds=self.total)}",
**self.kwargs,
**self.tqdm_kwargs,
):
if not self.terminated:
time.sleep(self.update_interval)

def stop(self) -> None:
"""Terminates the thread."""
def join(self, timeout: float | None = None) -> None:
"""Max out the progress bar, thereby terminating this thread."""
if not self.disable:
self.terminated = True
super().join()
super().join(timeout)