|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from typing import Any
|
2 | 4 |
|
3 | 5 | import datetime
|
|
10 | 12 | class ProgressBar(Thread):
|
11 | 13 | """A Thread that displays a tqdm progress bar in the console.
|
12 | 14 |
|
13 |
| - It is specialized to display information relevant to fitting to the training data |
14 |
| - with auto-sklearn. |
| 15 | + Treat this class as an ordinary thread. So to display a progress bar, |
| 16 | + call start() on an instance of this class. To wait for the thread to |
| 17 | + terminate call join(), which will max out the progress bar, |
| 18 | + therefore terminate this thread immediately. |
15 | 19 |
|
16 | 20 | Parameters
|
17 | 21 | ----------
|
18 | 22 | total : int
|
19 |
| - The total amount that should be reached by the progress bar once it finishes |
20 |
| - update_interval : float |
21 |
| - Specifies how frequently the progress bar is updated (in seconds) |
22 |
| - disable : bool |
23 |
| - Turns on or off the progress bar. If True, this thread won't be started or |
24 |
| - initialized. |
25 |
| - kwargs : Any |
| 23 | + The total amount that should be reached by the progress bar once it finishes. |
| 24 | + update_interval : float, default=1.0 |
| 25 | + Specifies how frequently the progress bar is updated (in seconds). |
| 26 | + disable : bool, default=False |
| 27 | + Turns on or off the progress bar. If True, this thread does not get |
| 28 | + initialized and won't be started if start() is called. |
| 29 | + tqdm_kwargs : Any, optional |
26 | 30 | Keyword arguments that are passed into tqdm's constructor. Refer to:
|
27 |
| - `tqdm <https://tqdm.github.io/docs/tqdm/>`_. Note that postfix can not be |
28 |
| - specified in the kwargs since it is already passed into tqdm by this class. |
| 31 | + `tqdm <https://tqdm.github.io/docs/tqdm/>`_ for a list of parameters that |
| 32 | + tqdm accepts. Note that 'postfix' cannot be specified in the kwargs since it is |
| 33 | + already passed into tqdm by this class. |
| 34 | +
|
| 35 | + Examples |
| 36 | + -------- |
| 37 | +
|
| 38 | + .. code:: python |
| 39 | +
|
| 40 | + progress_bar = ProgressBar( |
| 41 | + total=10, |
| 42 | + desc="Executing code that runs for 10 seconds", |
| 43 | + colour="green", |
| 44 | + ) |
| 45 | + # colour is a tqdm parameter passed as a tqdm_kwargs |
| 46 | + try: |
| 47 | + progress_bar.start() |
| 48 | + # some code that runs for 10 seconds |
| 49 | + except SomeException: |
| 50 | + # something went wrong |
| 51 | + finally: |
| 52 | + progress_bar.join() |
| 53 | + # perform some cleanup |
29 | 54 | """
|
30 | 55 |
|
31 | 56 | def __init__(
|
32 | 57 | self,
|
33 | 58 | total: int,
|
34 | 59 | update_interval: float = 1.0,
|
35 | 60 | disable: bool = False,
|
36 |
| - **kwargs: Any, |
| 61 | + **tqdm_kwargs: Any, |
37 | 62 | ):
|
38 | 63 | self.disable = disable
|
39 | 64 | if not disable:
|
40 | 65 | super().__init__(name="_progressbar_")
|
41 | 66 | self.total = total
|
42 | 67 | self.update_interval = update_interval
|
43 | 68 | self.terminated: bool = False
|
44 |
| - self.kwargs = kwargs |
45 |
| - # start this thread |
46 |
| - self.start() |
| 69 | + self.tqdm_kwargs = tqdm_kwargs |
47 | 70 |
|
48 |
| - def run(self) -> None: |
49 |
| - """Display a tqdm progress bar in the console. |
| 71 | + def start(self) -> None: |
| 72 | + """Start a new thread that calls the run() method.""" |
| 73 | + if not self.disable: |
| 74 | + super().start() |
50 | 75 |
|
51 |
| - Additionally, it shows useful information related to the task. This method |
52 |
| - overrides the run method of Thread. |
53 |
| - """ |
| 76 | + def run(self) -> None: |
| 77 | + """Display a tqdm progress bar in the console.""" |
54 | 78 | if not self.disable:
|
55 | 79 | for _ in trange(
|
56 | 80 | self.total,
|
57 | 81 | postfix=f"The total time budget for this task is "
|
58 | 82 | f"{datetime.timedelta(seconds=self.total)}",
|
59 |
| - **self.kwargs, |
| 83 | + **self.tqdm_kwargs, |
60 | 84 | ):
|
61 | 85 | if not self.terminated:
|
62 | 86 | time.sleep(self.update_interval)
|
63 | 87 |
|
64 |
| - def stop(self) -> None: |
65 |
| - """Terminates the thread.""" |
| 88 | + def join(self, timeout: float | None = None) -> None: |
| 89 | + """Maxes out the progress bar and thereby terminating this thread.""" |
66 | 90 | if not self.disable:
|
67 | 91 | self.terminated = True
|
68 |
| - super().join() |
| 92 | + super().join(timeout) |
0 commit comments