Skip to content

Commit d61c6d7

Browse files
committed
Add progress bar with tqdm.
1 parent d67e697 commit d61c6d7

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

labelled_functions/decorators.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,25 @@ def timed_func(*args, **kwargs):
5050
return timed_func
5151

5252

53+
def with_progress_bar(lab_f, total=None):
54+
"""Add a tqdm object to count calls and display a progress bar."""
55+
from tqdm import tqdm
56+
bar = tqdm(total=total, unit="calls")
57+
58+
def add_call_counter(f):
59+
def exec_with_call_counter(*args, **kwargs):
60+
out = f(*args, **kwargs)
61+
bar.update(1)
62+
return out
63+
return exec_with_call_counter
64+
65+
new_lab_f = copy(lab_f)
66+
new_lab_f.bar = bar
67+
new_lab_f.function = add_call_counter(lab_f.function)
68+
69+
return new_lab_f
70+
71+
5372
def decorate(func, decorator):
5473
new_func = copy(func)
5574
new_func.function = decorator(func.function)

labelled_functions/maps.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
import xarray as xr
99

1010
from .labels import label
11-
from .decorators import keeping_inputs
11+
from .decorators import keeping_inputs, with_progress_bar
1212

1313

1414
# API
1515

16-
def pandas_map(f, *args, n_jobs=1, **kwargs):
16+
def pandas_map(f, *args, progress_bar=False, n_jobs=1, **kwargs):
1717
f = label(f)
1818
dict_of_lists = _preprocess_map_inputs(f.input_names, args, kwargs)
19+
if progress_bar:
20+
f = with_progress_bar(f, total=len(any_value(dict_of_lists)))
1921
if n_jobs == 1:
2022
data = list(lmap(keeping_inputs(f), **dict_of_lists))
2123
else:
@@ -25,9 +27,11 @@ def pandas_map(f, *args, n_jobs=1, **kwargs):
2527
return _set_index(f.input_names, data)
2628

2729

28-
def pandas_cartesian_product(f, *args, n_jobs=1, **kwargs):
30+
def pandas_cartesian_product(f, *args, progress_bar=False, n_jobs=1, **kwargs):
2931
f = label(f)
3032
dict_of_lists = _preprocess_map_inputs(f.input_names, args, kwargs)
33+
if progress_bar:
34+
f = with_progress_bar(f, total=len(any_value(dict_of_lists)))
3135
if n_jobs == 1:
3236
data = list(lcartesianmap(keeping_inputs(f), **dict_of_lists))
3337
else:
@@ -91,3 +95,7 @@ def _set_index(indices, data):
9195
return data.set_index(indices)
9296
else:
9397
return data
98+
99+
def any_value(d):
100+
"""Returns one of the values of a dict."""
101+
return next(iter(d.values()))

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@
2222
'toolz',
2323
'parso',
2424
'joblib',
25+
'tqdm',
2526
],
2627
)

test/test_decorators.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
import pytest
55

6+
from time import sleep
7+
68
from labelled_functions import label, pipeline
7-
from labelled_functions.decorators import keeping_inputs, decorate, cache
9+
from labelled_functions.maps import pandas_map
10+
from labelled_functions.decorators import keeping_inputs, with_progress_bar, decorate, cache
811
from example_functions import *
912

1013

@@ -40,3 +43,16 @@ def test_cache():
4043
a = pipe(length=1.0)
4144
b = pipe(length=1.0)
4245
assert a == b
46+
47+
48+
def test_progress_bar(capfd):
49+
def wait(dt):
50+
sleep(dt)
51+
output = 1
52+
return output
53+
54+
pandas_map(wait, dt=[0.01]*10, progress_bar=True)
55+
56+
out, err = capfd.readouterr()
57+
assert "10/10" in err
58+

0 commit comments

Comments
 (0)