Skip to content

Commit 8d5ce40

Browse files
committed
fix update
1 parent 91ac8ee commit 8d5ce40

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

snippets/decorators.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
1616
from functools import wraps
1717
from typing import Generator, Iterable, List, Tuple
18+
from venv import logger
1819

1920

2021
from loguru import logger as default_logger
@@ -160,20 +161,33 @@ def wrapped(*args, **kwargs):
160161

161162

162163
# 线程池批量跑function
163-
def multi_thread(work_num, return_list=False, use_process=False):
164+
def multi_thread(work_num, return_list=False, safe_execute=True):
165+
"""多线程跑某个function的装饰器
166+
Args:
167+
work_num (_type_): 线程数
168+
return_list (bool, optional): 结果是否返回list. Defaults to False.
169+
safe_execute (bool, optional): 是不是catch住function中的exception并返回None. Defaults to True.
170+
"""
164171
def wrapper(func):
165172
@wraps(func)
166173
def wrapped(data: Iterable, *args, **kwargs):
167174
def _func(x):
168-
return func(x, *args, **kwargs)
169-
with ThreadPoolExecutor(work_num) as executors:
170-
rs_iter = executors.map(_func, data)
171-
rs = rs_iter
172-
total = None if not hasattr(data, '__len__') else len(data)
173-
rs = tqdm(rs_iter, total=total)
174-
if return_list:
175-
return list(rs)
176-
return rs
175+
try:
176+
return func(x, *args, **kwargs)
177+
except Exception as e:
178+
if safe_execute:
179+
logger.warning(f"function {func.__name__} failed with exception")
180+
return None
181+
else:
182+
raise e
183+
184+
executors = ThreadPoolExecutor(work_num)
185+
rs_iter = executors.map(_func, data)
186+
total = None if not hasattr(data, '__len__') else len(data)
187+
rs_iter = tqdm(rs_iter, total=total)
188+
rs_iter = (e for e in rs_iter if e is not None)
189+
190+
return list(rs_iter) if return_list else rs_iter
177191
return wrapped
178192
return wrapper
179193

@@ -184,14 +198,13 @@ def multi_process(work_num, return_list=False):
184198
def wrapper(func):
185199
@wraps(func)
186200
def wrapped(data: Iterable):
187-
with ProcessPoolExecutor(work_num) as executors:
188-
rs_iter = executors.map(func, data)
189-
rs = rs_iter
190-
total = None if not hasattr(data, '__len__') else len(data)
191-
rs = tqdm(rs_iter, total=total)
192-
if return_list:
193-
return list(rs)
194-
return rs
201+
executors = ProcessPoolExecutor(work_num)
202+
rs_iter = executors.map(func, data)
203+
rs = rs_iter
204+
total = None if not hasattr(data, '__len__') else len(data)
205+
rs_iter = tqdm(rs_iter, total=total)
206+
return list(rs_iter) if return_list else rs
207+
195208
return wrapped
196209
return wrapper
197210

tests/test_decorators.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ def add(a, b=1, sleep=False):
2020

2121
return a+b
2222

23+
def sleep_with_add(a):
24+
return add(a, sleep=True)
25+
26+
2327

2428
class TestUtils(unittest.TestCase):
2529
def test_adapt_single(self):
@@ -61,7 +65,7 @@ def test_multi_thread(self):
6165
self.assertListEqual([3, 4, 5, 6, 7, 8, 9, 10, 11,12], rs_list)
6266

6367
def test_multi_process(self):
64-
process_batch_fn = batch_process(work_num=4, return_list=True)(add)
68+
process_batch_fn = multi_process(work_num=4, return_list=True)(sleep_with_add)
6569
rs = process_batch_fn(data=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
6670
print(rs)
6771
self.assertListEqual([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], rs)

0 commit comments

Comments
 (0)