Skip to content

Threaded Decorator Type-Safety #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/thread/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from typing import Any, Literal, Callable, Union
from typing_extensions import ParamSpec, TypeVar


# Descriptive Types
Expand All @@ -27,5 +28,11 @@


# Function types
HookFunction = Callable[[Data_Out], Union[Any, None]]
TargetFunction = Callable[..., Data_Out]
_Target_P = ParamSpec('_Target_P')
_Target_T = TypeVar('_Target_T')
TargetFunction = Callable[_Target_P, _Target_T]

HookFunction = Callable[[_Target_T], Union[Any, None]]

_Dataset_T = TypeVar('_Dataset_T')
DatasetFunction = Callable[[_Dataset_T], _Target_T]
26 changes: 13 additions & 13 deletions src/thread/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@

T = TypeVar('T')
P = ParamSpec('P')
TargetFunction = Callable[P, Data_Out]
NoParamReturn = Callable[P, Thread]
WithParamReturn = Callable[[TargetFunction], NoParamReturn]
FullParamReturn = Callable[P, Thread]
WrappedWithParamReturn = Callable[[TargetFunction], WithParamReturn]
TargetFunction = Callable[P, T]
NoParamReturn = Callable[P, Thread[P, T]]
WithParamReturn = Callable[[TargetFunction[P, T]], NoParamReturn[P, T]]
FullParamReturn = Callable[P, Thread[P, T]]
WrappedWithParamReturn = Callable[[TargetFunction[P, T]], WithParamReturn[P, T]]


@overload
def threaded(__function: TargetFunction) -> NoParamReturn: ...
def threaded(__function: TargetFunction[P, T]) -> NoParamReturn[P, T]: ...

@overload
def threaded(
Expand All @@ -32,29 +32,29 @@ def threaded(
ignore_errors: Sequence[type[Exception]] = (),
suppress_errors: bool = False,
**overflow_kwargs: Overflow_In
) -> WithParamReturn: ...
) -> WithParamReturn[P, T]: ...

@overload
def threaded(
__function: Callable[P, Data_Out],
__function: TargetFunction[P, T],
*,
args: Sequence[Data_In] = (),
kwargs: Mapping[str, Data_In] = {},
ignore_errors: Sequence[type[Exception]] = (),
suppress_errors: bool = False,
**overflow_kwargs: Overflow_In
) -> FullParamReturn: ...
) -> FullParamReturn[P, T]: ...


def threaded(
__function: Optional[TargetFunction] = None,
__function: Optional[TargetFunction[P, T]] = None,
*,
args: Sequence[Data_In] = (),
kwargs: Mapping[str, Data_In] = {},
ignore_errors: Sequence[type[Exception]] = (),
suppress_errors: bool = False,
**overflow_kwargs: Overflow_In
) -> Union[NoParamReturn, WithParamReturn, FullParamReturn]:
) -> Union[NoParamReturn[P, T], WithParamReturn[P, T], FullParamReturn[P, T]]:
"""
Decorate a function to run it in a thread

Expand Down Expand Up @@ -96,7 +96,7 @@ def threaded(
"""

if not callable(__function):
def wrapper(func: TargetFunction) -> FullParamReturn:
def wrapper(func: TargetFunction[P, T]) -> FullParamReturn[P, T]:
return threaded(
func,
args = args,
Expand All @@ -115,7 +115,7 @@ def wrapper(func: TargetFunction) -> FullParamReturn:
kwargs = dict(kwargs)

@wraps(__function)
def wrapped(*parsed_args: P.args, **parsed_kwargs: P.kwargs) -> Thread:
def wrapped(*parsed_args: P.args, **parsed_kwargs: P.kwargs) -> Thread[P, T]:
kwargs.update(parsed_kwargs)

processed_args = ( *args, *parsed_args )
Expand Down
55 changes: 32 additions & 23 deletions src/thread/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,22 @@ class ParallelProcessing: ...
from .utils.config import Settings
from .utils.algorithm import chunk_split

from ._types import ThreadStatus, Data_In, Data_Out, Overflow_In, TargetFunction, HookFunction
from ._types import (
ThreadStatus, Data_In, Data_Out, Overflow_In,
TargetFunction, _Target_P, _Target_T,
DatasetFunction, _Dataset_T,
HookFunction
)
from typing_extensions import Generic, ParamSpec
from typing import (
Any, List,
Callable, Optional,
List,
Callable, Optional, Union,
Mapping, Sequence, Tuple
)


Threads: set['Thread'] = set()
class Thread(threading.Thread):
class Thread(threading.Thread, Generic[_Target_P, _Target_T]):
"""
Wraps python's `threading.Thread` class
---------------------------------------
Expand All @@ -38,7 +44,7 @@ class Thread(threading.Thread):

status : ThreadStatus
hooks : List[HookFunction]
returned_value: Data_Out
_returned_value: Data_Out

errors : List[Exception]
ignore_errors : Sequence[type[Exception]]
Expand All @@ -51,7 +57,7 @@ class Thread(threading.Thread):

def __init__(
self,
target: TargetFunction,
target: TargetFunction[_Target_P, _Target_T],
args: Sequence[Data_In] = (),
kwargs: Mapping[str, Data_In] = {},
ignore_errors: Sequence[type[Exception]] = (),
Expand Down Expand Up @@ -80,7 +86,7 @@ def __init__(
:param **: These are arguments parsed to `thread.Thread`
"""
_target = self._wrap_target(target)
self.returned_value = None
self._returned_value = None
self.status = 'Idle'
self.hooks = []

Expand All @@ -100,17 +106,17 @@ def __init__(
)


def _wrap_target(self, target: TargetFunction) -> TargetFunction:
def _wrap_target(self, target: TargetFunction[_Target_P, _Target_T]) -> TargetFunction[_Target_P, Union[_Target_T, None]]:
"""Wraps the target function"""
@wraps(target)
def wrapper(*args: Any, **kwargs: Any) -> Any:
def wrapper(*args: _Target_P.args, **kwargs: _Target_P.kwargs) -> Union[_Target_T, None]:
self.status = 'Running'

global Threads
Threads.add(self)

try:
self.returned_value = target(*args, **kwargs)
self._returned_value = target(*args, **kwargs)
except Exception as e:
if not any(isinstance(e, ignore) for ignore in self.ignore_errors):
self.status = 'Errored'
Expand All @@ -129,7 +135,7 @@ def _invoke_hooks(self) -> None:
errors: List[Tuple[Exception, str]] = []
for hook in self.hooks:
try:
hook(self.returned_value)
hook(self._returned_value)
except Exception as e:
if not any(isinstance(e, ignore) for ignore in self.ignore_errors):
errors.append((
Expand Down Expand Up @@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:


@property
def result(self) -> Data_Out:
def result(self) -> _Target_T:
"""
The return value of the thread

Expand All @@ -190,7 +196,7 @@ def result(self) -> Data_Out:

self._handle_exceptions()
if self.status in ['Invoking hooks', 'Completed']:
return self.returned_value
return self._returned_value
else:
raise exceptions.ThreadStillRunningError()

Expand All @@ -208,7 +214,7 @@ def is_alive(self) -> bool:
return super().is_alive()


def add_hook(self, hook: HookFunction) -> None:
def add_hook(self, hook: HookFunction[_Target_T]) -> None:
"""
Adds a hook to the thread
-------------------------
Expand Down Expand Up @@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
return not self.is_alive()


def get_return_value(self) -> Data_Out:
def get_return_value(self) -> _Target_T:
"""
Halts the current thread execution until the thread completes

Expand Down Expand Up @@ -315,6 +321,7 @@ def start(self) -> None:



_P = ParamSpec('_P')
class _ThreadWorker:
progress: float
thread: Thread
Expand All @@ -323,7 +330,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
self.thread = thread
self.progress = progress

class ParallelProcessing:
class ParallelProcessing(Generic[_Target_P, _Target_T, _Dataset_T]):
"""
Multi-Threaded Parallel Processing
---------------------------------------
Expand All @@ -335,7 +342,7 @@ class ParallelProcessing:
_completed : int

status : ThreadStatus
function : Callable[..., List[Data_Out]]
function : TargetFunction
dataset : Sequence[Data_In]
max_threads : int

Expand All @@ -344,8 +351,8 @@ class ParallelProcessing:

def __init__(
self,
function: TargetFunction,
dataset: Sequence[Data_In],
function: DatasetFunction[_Dataset_T, _Target_T],
dataset: Sequence[_Dataset_T],
max_threads: int = 8,

*overflow_args: Overflow_In,
Expand Down Expand Up @@ -386,9 +393,9 @@ def __init__(
def _wrap_function(
self,
function: TargetFunction
) -> Callable[..., List[Data_Out]]:
) -> TargetFunction:
@wraps(function)
def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
def wrapper(index: int, data_chunk: Sequence[_Dataset_T], *args: _Target_P.args, **kwargs: _Target_P.kwargs) -> List[_Target_T]:
computed: List[Data_Out] = []
for i, data_entry in enumerate(data_chunk):
v = function(data_entry, *args, **kwargs)
Expand All @@ -404,7 +411,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any


@property
def results(self) -> Data_Out:
def results(self) -> List[_Dataset_T]:
"""
The return value of the threads if completed

Expand Down Expand Up @@ -436,7 +443,7 @@ def is_alive(self) -> bool:
return any(entry.thread.is_alive() for entry in self._threads)


def get_return_values(self) -> List[Data_Out]:
def get_return_values(self) -> List[_Dataset_T]:
"""
Halts the current thread execution until the thread completes

Expand Down Expand Up @@ -506,6 +513,8 @@ def start(self) -> None:
name_format = self.overflow_kwargs.get('name') and self.overflow_kwargs['name'] + '%s'
self.overflow_kwargs = { i: v for i,v in self.overflow_kwargs.items() if i != 'name' and i != 'args' }

print(parsed_args, self.overflow_args)

for i, data_chunk in enumerate(chunk_split(self.dataset, max_threads)):
chunk_thread = Thread(
target = self.function,
Expand Down
12 changes: 1 addition & 11 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
import time
import pytest
from src.thread import threaded, exceptions
from src.thread import threaded


# >>>>>>>>>> Dummy Functions <<<<<<<<<< #
def _dummy_target_raiseToPower(x: float, power: float, delay: float = 0):
time.sleep(delay)
return x**power

def _dummy_raiseException(x: Exception, delay: float = 0):
time.sleep(delay)
raise x

def _dummy_iterative(itemCount: int, pTime: float = 0.1, delay: float = 0):
time.sleep(delay)
for i in range(itemCount):
time.sleep(pTime)




Expand Down