Skip to content

Commit b4675ed

Browse files
Merge pull request #33 from python-thread/dev
Threaded Decorator Type-Safety
2 parents 11e6074 + adb2f13 commit b4675ed

File tree

4 files changed

+55
-49
lines changed

4 files changed

+55
-49
lines changed

src/thread/_types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from typing import Any, Literal, Callable, Union
8+
from typing_extensions import ParamSpec, TypeVar
89

910

1011
# Descriptive Types
@@ -27,5 +28,11 @@
2728

2829

2930
# Function types
30-
HookFunction = Callable[[Data_Out], Union[Any, None]]
31-
TargetFunction = Callable[..., Data_Out]
31+
_Target_P = ParamSpec('_Target_P')
32+
_Target_T = TypeVar('_Target_T')
33+
TargetFunction = Callable[_Target_P, _Target_T]
34+
35+
HookFunction = Callable[[_Target_T], Union[Any, None]]
36+
37+
_Dataset_T = TypeVar('_Dataset_T')
38+
DatasetFunction = Callable[[_Dataset_T], _Target_T]

src/thread/decorators.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
T = TypeVar('T')
1616
P = ParamSpec('P')
17-
TargetFunction = Callable[P, Data_Out]
18-
NoParamReturn = Callable[P, Thread]
19-
WithParamReturn = Callable[[TargetFunction], NoParamReturn]
20-
FullParamReturn = Callable[P, Thread]
21-
WrappedWithParamReturn = Callable[[TargetFunction], WithParamReturn]
17+
TargetFunction = Callable[P, T]
18+
NoParamReturn = Callable[P, Thread[P, T]]
19+
WithParamReturn = Callable[[TargetFunction[P, T]], NoParamReturn[P, T]]
20+
FullParamReturn = Callable[P, Thread[P, T]]
21+
WrappedWithParamReturn = Callable[[TargetFunction[P, T]], WithParamReturn[P, T]]
2222

2323

2424
@overload
25-
def threaded(__function: TargetFunction) -> NoParamReturn: ...
25+
def threaded(__function: TargetFunction[P, T]) -> NoParamReturn[P, T]: ...
2626

2727
@overload
2828
def threaded(
@@ -32,29 +32,29 @@ def threaded(
3232
ignore_errors: Sequence[type[Exception]] = (),
3333
suppress_errors: bool = False,
3434
**overflow_kwargs: Overflow_In
35-
) -> WithParamReturn: ...
35+
) -> WithParamReturn[P, T]: ...
3636

3737
@overload
3838
def threaded(
39-
__function: Callable[P, Data_Out],
39+
__function: TargetFunction[P, T],
4040
*,
4141
args: Sequence[Data_In] = (),
4242
kwargs: Mapping[str, Data_In] = {},
4343
ignore_errors: Sequence[type[Exception]] = (),
4444
suppress_errors: bool = False,
4545
**overflow_kwargs: Overflow_In
46-
) -> FullParamReturn: ...
46+
) -> FullParamReturn[P, T]: ...
4747

4848

4949
def threaded(
50-
__function: Optional[TargetFunction] = None,
50+
__function: Optional[TargetFunction[P, T]] = None,
5151
*,
5252
args: Sequence[Data_In] = (),
5353
kwargs: Mapping[str, Data_In] = {},
5454
ignore_errors: Sequence[type[Exception]] = (),
5555
suppress_errors: bool = False,
5656
**overflow_kwargs: Overflow_In
57-
) -> Union[NoParamReturn, WithParamReturn, FullParamReturn]:
57+
) -> Union[NoParamReturn[P, T], WithParamReturn[P, T], FullParamReturn[P, T]]:
5858
"""
5959
Decorate a function to run it in a thread
6060
@@ -96,7 +96,7 @@ def threaded(
9696
"""
9797

9898
if not callable(__function):
99-
def wrapper(func: TargetFunction) -> FullParamReturn:
99+
def wrapper(func: TargetFunction[P, T]) -> FullParamReturn[P, T]:
100100
return threaded(
101101
func,
102102
args = args,
@@ -115,7 +115,7 @@ def wrapper(func: TargetFunction) -> FullParamReturn:
115115
kwargs = dict(kwargs)
116116

117117
@wraps(__function)
118-
def wrapped(*parsed_args: P.args, **parsed_kwargs: P.kwargs) -> Thread:
118+
def wrapped(*parsed_args: P.args, **parsed_kwargs: P.kwargs) -> Thread[P, T]:
119119
kwargs.update(parsed_kwargs)
120120

121121
processed_args = ( *args, *parsed_args )

src/thread/thread.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,22 @@ class ParallelProcessing: ...
1919
from .utils.config import Settings
2020
from .utils.algorithm import chunk_split
2121

22-
from ._types import ThreadStatus, Data_In, Data_Out, Overflow_In, TargetFunction, HookFunction
22+
from ._types import (
23+
ThreadStatus, Data_In, Data_Out, Overflow_In,
24+
TargetFunction, _Target_P, _Target_T,
25+
DatasetFunction, _Dataset_T,
26+
HookFunction
27+
)
28+
from typing_extensions import Generic, ParamSpec
2329
from typing import (
24-
Any, List,
25-
Callable, Optional,
30+
List,
31+
Callable, Optional, Union,
2632
Mapping, Sequence, Tuple
2733
)
2834

2935

3036
Threads: set['Thread'] = set()
31-
class Thread(threading.Thread):
37+
class Thread(threading.Thread, Generic[_Target_P, _Target_T]):
3238
"""
3339
Wraps python's `threading.Thread` class
3440
---------------------------------------
@@ -38,7 +44,7 @@ class Thread(threading.Thread):
3844

3945
status : ThreadStatus
4046
hooks : List[HookFunction]
41-
returned_value: Data_Out
47+
_returned_value: Data_Out
4248

4349
errors : List[Exception]
4450
ignore_errors : Sequence[type[Exception]]
@@ -51,7 +57,7 @@ class Thread(threading.Thread):
5157

5258
def __init__(
5359
self,
54-
target: TargetFunction,
60+
target: TargetFunction[_Target_P, _Target_T],
5561
args: Sequence[Data_In] = (),
5662
kwargs: Mapping[str, Data_In] = {},
5763
ignore_errors: Sequence[type[Exception]] = (),
@@ -80,7 +86,7 @@ def __init__(
8086
:param **: These are arguments parsed to `thread.Thread`
8187
"""
8288
_target = self._wrap_target(target)
83-
self.returned_value = None
89+
self._returned_value = None
8490
self.status = 'Idle'
8591
self.hooks = []
8692

@@ -100,17 +106,17 @@ def __init__(
100106
)
101107

102108

103-
def _wrap_target(self, target: TargetFunction) -> TargetFunction:
109+
def _wrap_target(self, target: TargetFunction[_Target_P, _Target_T]) -> TargetFunction[_Target_P, Union[_Target_T, None]]:
104110
"""Wraps the target function"""
105111
@wraps(target)
106-
def wrapper(*args: Any, **kwargs: Any) -> Any:
112+
def wrapper(*args: _Target_P.args, **kwargs: _Target_P.kwargs) -> Union[_Target_T, None]:
107113
self.status = 'Running'
108114

109115
global Threads
110116
Threads.add(self)
111117

112118
try:
113-
self.returned_value = target(*args, **kwargs)
119+
self._returned_value = target(*args, **kwargs)
114120
except Exception as e:
115121
if not any(isinstance(e, ignore) for ignore in self.ignore_errors):
116122
self.status = 'Errored'
@@ -129,7 +135,7 @@ def _invoke_hooks(self) -> None:
129135
errors: List[Tuple[Exception, str]] = []
130136
for hook in self.hooks:
131137
try:
132-
hook(self.returned_value)
138+
hook(self._returned_value)
133139
except Exception as e:
134140
if not any(isinstance(e, ignore) for ignore in self.ignore_errors):
135141
errors.append((
@@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:
173179

174180

175181
@property
176-
def result(self) -> Data_Out:
182+
def result(self) -> _Target_T:
177183
"""
178184
The return value of the thread
179185
@@ -190,7 +196,7 @@ def result(self) -> Data_Out:
190196

191197
self._handle_exceptions()
192198
if self.status in ['Invoking hooks', 'Completed']:
193-
return self.returned_value
199+
return self._returned_value
194200
else:
195201
raise exceptions.ThreadStillRunningError()
196202

@@ -208,7 +214,7 @@ def is_alive(self) -> bool:
208214
return super().is_alive()
209215

210216

211-
def add_hook(self, hook: HookFunction) -> None:
217+
def add_hook(self, hook: HookFunction[_Target_T]) -> None:
212218
"""
213219
Adds a hook to the thread
214220
-------------------------
@@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
250256
return not self.is_alive()
251257

252258

253-
def get_return_value(self) -> Data_Out:
259+
def get_return_value(self) -> _Target_T:
254260
"""
255261
Halts the current thread execution until the thread completes
256262
@@ -315,6 +321,7 @@ def start(self) -> None:
315321

316322

317323

324+
_P = ParamSpec('_P')
318325
class _ThreadWorker:
319326
progress: float
320327
thread: Thread
@@ -323,7 +330,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
323330
self.thread = thread
324331
self.progress = progress
325332

326-
class ParallelProcessing:
333+
class ParallelProcessing(Generic[_Target_P, _Target_T, _Dataset_T]):
327334
"""
328335
Multi-Threaded Parallel Processing
329336
---------------------------------------
@@ -335,7 +342,7 @@ class ParallelProcessing:
335342
_completed : int
336343

337344
status : ThreadStatus
338-
function : Callable[..., List[Data_Out]]
345+
function : TargetFunction
339346
dataset : Sequence[Data_In]
340347
max_threads : int
341348

@@ -344,8 +351,8 @@ class ParallelProcessing:
344351

345352
def __init__(
346353
self,
347-
function: TargetFunction,
348-
dataset: Sequence[Data_In],
354+
function: DatasetFunction[_Dataset_T, _Target_T],
355+
dataset: Sequence[_Dataset_T],
349356
max_threads: int = 8,
350357

351358
*overflow_args: Overflow_In,
@@ -386,9 +393,9 @@ def __init__(
386393
def _wrap_function(
387394
self,
388395
function: TargetFunction
389-
) -> Callable[..., List[Data_Out]]:
396+
) -> TargetFunction:
390397
@wraps(function)
391-
def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
398+
def wrapper(index: int, data_chunk: Sequence[_Dataset_T], *args: _Target_P.args, **kwargs: _Target_P.kwargs) -> List[_Target_T]:
392399
computed: List[Data_Out] = []
393400
for i, data_entry in enumerate(data_chunk):
394401
v = function(data_entry, *args, **kwargs)
@@ -404,7 +411,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any
404411

405412

406413
@property
407-
def results(self) -> Data_Out:
414+
def results(self) -> List[_Dataset_T]:
408415
"""
409416
The return value of the threads if completed
410417
@@ -436,7 +443,7 @@ def is_alive(self) -> bool:
436443
return any(entry.thread.is_alive() for entry in self._threads)
437444

438445

439-
def get_return_values(self) -> List[Data_Out]:
446+
def get_return_values(self) -> List[_Dataset_T]:
440447
"""
441448
Halts the current thread execution until the thread completes
442449
@@ -506,6 +513,8 @@ def start(self) -> None:
506513
name_format = self.overflow_kwargs.get('name') and self.overflow_kwargs['name'] + '%s'
507514
self.overflow_kwargs = { i: v for i,v in self.overflow_kwargs.items() if i != 'name' and i != 'args' }
508515

516+
print(parsed_args, self.overflow_args)
517+
509518
for i, data_chunk in enumerate(chunk_split(self.dataset, max_threads)):
510519
chunk_thread = Thread(
511520
target = self.function,

tests/test_decorator.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,12 @@
11
import time
2-
import pytest
3-
from src.thread import threaded, exceptions
2+
from src.thread import threaded
43

54

65
# >>>>>>>>>> Dummy Functions <<<<<<<<<< #
76
def _dummy_target_raiseToPower(x: float, power: float, delay: float = 0):
87
time.sleep(delay)
98
return x**power
109

11-
def _dummy_raiseException(x: Exception, delay: float = 0):
12-
time.sleep(delay)
13-
raise x
14-
15-
def _dummy_iterative(itemCount: int, pTime: float = 0.1, delay: float = 0):
16-
time.sleep(delay)
17-
for i in range(itemCount):
18-
time.sleep(pTime)
19-
2010

2111

2212

0 commit comments

Comments
 (0)