@@ -19,16 +19,22 @@ class ParallelProcessing: ...
19
19
from .utils .config import Settings
20
20
from .utils .algorithm import chunk_split
21
21
22
- from ._types import ThreadStatus , Data_In , Data_Out , Overflow_In , TargetFunction , HookFunction
22
+ from ._types import (
23
+ ThreadStatus , Data_In , Data_Out , Overflow_In ,
24
+ TargetFunction , _Target_P , _Target_T ,
25
+ DatasetFunction , _Dataset_T ,
26
+ HookFunction
27
+ )
28
+ from typing_extensions import Generic , ParamSpec
23
29
from typing import (
24
- Any , List ,
25
- Callable , Optional ,
30
+ List ,
31
+ Callable , Optional , Union ,
26
32
Mapping , Sequence , Tuple
27
33
)
28
34
29
35
30
36
Threads : set ['Thread' ] = set ()
31
- class Thread (threading .Thread ):
37
+ class Thread (threading .Thread , Generic [ _Target_P , _Target_T ] ):
32
38
"""
33
39
Wraps python's `threading.Thread` class
34
40
---------------------------------------
@@ -38,7 +44,7 @@ class Thread(threading.Thread):
38
44
39
45
status : ThreadStatus
40
46
hooks : List [HookFunction ]
41
- returned_value : Data_Out
47
+ _returned_value : Data_Out
42
48
43
49
errors : List [Exception ]
44
50
ignore_errors : Sequence [type [Exception ]]
@@ -51,7 +57,7 @@ class Thread(threading.Thread):
51
57
52
58
def __init__ (
53
59
self ,
54
- target : TargetFunction ,
60
+ target : TargetFunction [ _Target_P , _Target_T ] ,
55
61
args : Sequence [Data_In ] = (),
56
62
kwargs : Mapping [str , Data_In ] = {},
57
63
ignore_errors : Sequence [type [Exception ]] = (),
@@ -80,7 +86,7 @@ def __init__(
80
86
:param **: These are arguments parsed to `thread.Thread`
81
87
"""
82
88
_target = self ._wrap_target (target )
83
- self .returned_value = None
89
+ self ._returned_value = None
84
90
self .status = 'Idle'
85
91
self .hooks = []
86
92
@@ -100,17 +106,17 @@ def __init__(
100
106
)
101
107
102
108
103
- def _wrap_target (self , target : TargetFunction ) -> TargetFunction :
109
+ def _wrap_target (self , target : TargetFunction [ _Target_P , _Target_T ] ) -> TargetFunction [ _Target_P , Union [ _Target_T , None ]] :
104
110
"""Wraps the target function"""
105
111
@wraps (target )
106
- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
112
+ def wrapper (* args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> Union [ _Target_T , None ] :
107
113
self .status = 'Running'
108
114
109
115
global Threads
110
116
Threads .add (self )
111
117
112
118
try :
113
- self .returned_value = target (* args , ** kwargs )
119
+ self ._returned_value = target (* args , ** kwargs )
114
120
except Exception as e :
115
121
if not any (isinstance (e , ignore ) for ignore in self .ignore_errors ):
116
122
self .status = 'Errored'
@@ -129,7 +135,7 @@ def _invoke_hooks(self) -> None:
129
135
errors : List [Tuple [Exception , str ]] = []
130
136
for hook in self .hooks :
131
137
try :
132
- hook (self .returned_value )
138
+ hook (self ._returned_value )
133
139
except Exception as e :
134
140
if not any (isinstance (e , ignore ) for ignore in self .ignore_errors ):
135
141
errors .append ((
@@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:
173
179
174
180
175
181
@property
176
- def result (self ) -> Data_Out :
182
+ def result (self ) -> _Target_T :
177
183
"""
178
184
The return value of the thread
179
185
@@ -190,7 +196,7 @@ def result(self) -> Data_Out:
190
196
191
197
self ._handle_exceptions ()
192
198
if self .status in ['Invoking hooks' , 'Completed' ]:
193
- return self .returned_value
199
+ return self ._returned_value
194
200
else :
195
201
raise exceptions .ThreadStillRunningError ()
196
202
@@ -208,7 +214,7 @@ def is_alive(self) -> bool:
208
214
return super ().is_alive ()
209
215
210
216
211
- def add_hook (self , hook : HookFunction ) -> None :
217
+ def add_hook (self , hook : HookFunction [ _Target_T ] ) -> None :
212
218
"""
213
219
Adds a hook to the thread
214
220
-------------------------
@@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
250
256
return not self .is_alive ()
251
257
252
258
253
- def get_return_value (self ) -> Data_Out :
259
+ def get_return_value (self ) -> _Target_T :
254
260
"""
255
261
Halts the current thread execution until the thread completes
256
262
@@ -315,6 +321,7 @@ def start(self) -> None:
315
321
316
322
317
323
324
+ _P = ParamSpec ('_P' )
318
325
class _ThreadWorker :
319
326
progress : float
320
327
thread : Thread
@@ -323,7 +330,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
323
330
self .thread = thread
324
331
self .progress = progress
325
332
326
- class ParallelProcessing :
333
+ class ParallelProcessing ( Generic [ _Target_P , _Target_T , _Dataset_T ]) :
327
334
"""
328
335
Multi-Threaded Parallel Processing
329
336
---------------------------------------
@@ -335,7 +342,7 @@ class ParallelProcessing:
335
342
_completed : int
336
343
337
344
status : ThreadStatus
338
- function : Callable [..., List [ Data_Out ]]
345
+ function : TargetFunction
339
346
dataset : Sequence [Data_In ]
340
347
max_threads : int
341
348
@@ -344,8 +351,8 @@ class ParallelProcessing:
344
351
345
352
def __init__ (
346
353
self ,
347
- function : TargetFunction ,
348
- dataset : Sequence [Data_In ],
354
+ function : DatasetFunction [ _Dataset_T , _Target_T ] ,
355
+ dataset : Sequence [_Dataset_T ],
349
356
max_threads : int = 8 ,
350
357
351
358
* overflow_args : Overflow_In ,
@@ -386,9 +393,9 @@ def __init__(
386
393
def _wrap_function (
387
394
self ,
388
395
function : TargetFunction
389
- ) -> Callable [..., List [ Data_Out ]] :
396
+ ) -> TargetFunction :
390
397
@wraps (function )
391
- def wrapper (index : int , data_chunk : Sequence [Data_In ], * args : Any , ** kwargs : Any ) -> List [Data_Out ]:
398
+ def wrapper (index : int , data_chunk : Sequence [_Dataset_T ], * args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> List [_Target_T ]:
392
399
computed : List [Data_Out ] = []
393
400
for i , data_entry in enumerate (data_chunk ):
394
401
v = function (data_entry , * args , ** kwargs )
@@ -404,7 +411,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any
404
411
405
412
406
413
@property
407
- def results (self ) -> Data_Out :
414
+ def results (self ) -> List [ _Dataset_T ] :
408
415
"""
409
416
The return value of the threads if completed
410
417
@@ -436,7 +443,7 @@ def is_alive(self) -> bool:
436
443
return any (entry .thread .is_alive () for entry in self ._threads )
437
444
438
445
439
- def get_return_values (self ) -> List [Data_Out ]:
446
+ def get_return_values (self ) -> List [_Dataset_T ]:
440
447
"""
441
448
Halts the current thread execution until the thread completes
442
449
@@ -506,6 +513,8 @@ def start(self) -> None:
506
513
name_format = self .overflow_kwargs .get ('name' ) and self .overflow_kwargs ['name' ] + '%s'
507
514
self .overflow_kwargs = { i : v for i ,v in self .overflow_kwargs .items () if i != 'name' and i != 'args' }
508
515
516
+ print (parsed_args , self .overflow_args )
517
+
509
518
for i , data_chunk in enumerate (chunk_split (self .dataset , max_threads )):
510
519
chunk_thread = Thread (
511
520
target = self .function ,
0 commit comments