Skip to content

Commit 59ea4b0

Browse files
authored
refactor: use progress_bar more explicitly as a thread (#1622)
1 parent 1abd1f9 commit 59ea4b0

File tree

2 files changed

+50
-25
lines changed

2 files changed

+50
-25
lines changed

autosklearn/automl.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ def fit(
652652
# space
653653
self._backend.save_start_time(self._seed)
654654

655+
progress_bar.start()
655656
self._stopwatch = StopWatch()
656657

657658
# Make sure that input is valid
@@ -970,7 +971,7 @@ def fit(
970971
self._logger.exception(e)
971972
raise e
972973
finally:
973-
progress_bar.stop()
974+
progress_bar.join()
974975
self._fit_cleanup()
975976

976977
self.fitted = True

autosklearn/util/progress_bar.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from typing import Any
24

35
import datetime
@@ -10,59 +12,81 @@
1012
class ProgressBar(Thread):
1113
"""A Thread that displays a tqdm progress bar in the console.
1214
13-
It is specialized to display information relevant to fitting to the training data
14-
with auto-sklearn.
15+
Treat this class as an ordinary thread. So to display a progress bar,
16+
call start() on an instance of this class. To wait for the thread to
17+
terminate call join(), which will max out the progress bar,
18+
therefore terminate this thread immediately.
1519
1620
Parameters
1721
----------
1822
total : int
19-
The total amount that should be reached by the progress bar once it finishes
20-
update_interval : float
21-
Specifies how frequently the progress bar is updated (in seconds)
22-
disable : bool
23-
Turns on or off the progress bar. If True, this thread won't be started or
24-
initialized.
25-
kwargs : Any
23+
The total amount that should be reached by the progress bar once it finishes.
24+
update_interval : float, default=1.0
25+
Specifies how frequently the progress bar is updated (in seconds).
26+
disable : bool, default=False
27+
Turns on or off the progress bar. If True, this thread does not get
28+
initialized and won't be started if start() is called.
29+
tqdm_kwargs : Any, optional
2630
Keyword arguments that are passed into tqdm's constructor. Refer to:
27-
`tqdm <https://tqdm.github.io/docs/tqdm/>`_. Note that postfix can not be
28-
specified in the kwargs since it is already passed into tqdm by this class.
31+
`tqdm <https://tqdm.github.io/docs/tqdm/>`_ for a list of parameters that
32+
tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is
33+
already passed into tqdm by this class.
34+
35+
Examples
36+
--------
37+
38+
.. code:: python
39+
40+
progress_bar = ProgressBar(
41+
total=10,
42+
desc="Executing code that runs for 10 seconds",
43+
colour="green",
44+
)
45+
# colour is a tqdm parameter passed as a tqdm_kwargs
46+
try:
47+
progress_bar.start()
48+
# some code that runs for 10 seconds
49+
except SomeException:
50+
# something went wrong
51+
finally:
52+
progress_bar.join()
53+
# perform some cleanup
2954
"""
3055

3156
def __init__(
3257
self,
3358
total: int,
3459
update_interval: float = 1.0,
3560
disable: bool = False,
36-
**kwargs: Any,
61+
**tqdm_kwargs: Any,
3762
):
3863
self.disable = disable
3964
if not disable:
4065
super().__init__(name="_progressbar_")
4166
self.total = total
4267
self.update_interval = update_interval
4368
self.terminated: bool = False
44-
self.kwargs = kwargs
45-
# start this thread
46-
self.start()
69+
self.tqdm_kwargs = tqdm_kwargs
4770

48-
def run(self) -> None:
49-
"""Display a tqdm progress bar in the console.
71+
def start(self) -> None:
72+
"""Start a new thread that calls the run() method."""
73+
if not self.disable:
74+
super().start()
5075

51-
Additionally, it shows useful information related to the task. This method
52-
overrides the run method of Thread.
53-
"""
76+
def run(self) -> None:
77+
"""Display a tqdm progress bar in the console."""
5478
if not self.disable:
5579
for _ in trange(
5680
self.total,
5781
postfix=f"The total time budget for this task is "
5882
f"{datetime.timedelta(seconds=self.total)}",
59-
**self.kwargs,
83+
**self.tqdm_kwargs,
6084
):
6185
if not self.terminated:
6286
time.sleep(self.update_interval)
6387

64-
def stop(self) -> None:
65-
"""Terminates the thread."""
88+
def join(self, timeout: float | None = None) -> None:
89+
"""Maxes out the progress bar and thereby terminating this thread."""
6690
if not self.disable:
6791
self.terminated = True
68-
super().join()
92+
super().join(timeout)

0 commit comments

Comments
 (0)