Skip to content

Commit a5bc6f9

Browse files
committed
Added generics support.
Signed-off-by: Pavel Kirilin <win10@list.ru>
1 parent e2479d2 commit a5bc6f9

File tree

6 files changed

+263
-13
lines changed

6 files changed

+263
-13
lines changed

.flake8

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ ignore =
9090
B008,
9191
; Found except `BaseException`
9292
WPS424,
93+
; Found a too complex `f` string
94+
WPS237,
9395

9496
; all init files
9597
__init__.py:

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: Testing
22

3-
on: push
3+
on: [push, pull_request]
44

55
jobs:
66
pre_job:

README.md

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,87 @@ with graph.sync_ctx(exception_propagation=False) as ctx:
176176

177177

178178
```
179+
180+
181+
## Generics support
182+
183+
We support generics substitution for class-based dependencies.
184+
For example, let's define an interface and a class. This class can be
185+
parameterized with some type and we consider this type a dependency.
186+
187+
```python
188+
import abc
189+
from typing import Any, Generic, TypeVar
190+
191+
class MyInterface(abc.ABC):
192+
@abc.abstractmethod
193+
def getval(self) -> Any:
194+
...
195+
196+
197+
_T = TypeVar("_T", bound=MyInterface)
198+
199+
200+
class MyClass(Generic[_T]):
201+
# We don't know exact type, but we assume
202+
# that it can be used as a dependency.
203+
def __init__(self, resource: _T = Depends()):
204+
self.resource = resource
205+
206+
@property
207+
def my_value(self) -> Any:
208+
return self.resource.getval()
209+
210+
```
211+
212+
Now let's create several implementation of defined interface:
213+
214+
```python
215+
216+
def getstr() -> str:
217+
return "strstr"
218+
219+
220+
def getint() -> int:
221+
return 100
222+
223+
224+
class MyDep1(MyInterface):
225+
def __init__(self, s: str = Depends(getstr)) -> None:
226+
self.s = s
227+
228+
def getval(self) -> str:
229+
return self.s
230+
231+
232+
class MyDep2(MyInterface):
233+
def __init__(self, i: int = Depends(getint)) -> None:
234+
self.i = i
235+
236+
def getval(self) -> int:
237+
return self.i
238+
239+
```
240+
241+
Now you can use these dependencies by just setting proper type hints.
242+
243+
```python
244+
def my_target(
245+
d1: MyClass[MyDep1] = Depends(),
246+
d2: MyClass[MyDep2] = Depends(),
247+
) -> None:
248+
print(d1.my_value)
249+
print(d2.my_value)
250+
251+
252+
with DependencyGraph(my_target).sync_ctx() as ctx:
253+
my_target(**ctx.resolve_kwargs())
254+
255+
```
256+
257+
This code will is going to print:
258+
259+
```
260+
strstr
261+
100
262+
```

taskiq_dependencies/dependency.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,22 @@ class Dependency:
102102
and calculate before execution.
103103
"""
104104

105-
def __init__( # noqa: WPS234
105+
def __init__( # noqa: WPS211, WPS234
106106
self,
107107
dependency: Optional[Union[Type[Any], Callable[..., Any]]] = None,
108108
*,
109109
use_cache: bool = True,
110110
kwargs: Optional[Dict[str, Any]] = None,
111111
signature: Optional[inspect.Parameter] = None,
112+
parent: "Optional[Dependency]" = None,
112113
) -> None:
113114
self._id = uuid.uuid4()
114115
self.dependency = dependency
115116
self.use_cache = use_cache
116117
self.param_name = ""
117118
self.kwargs = kwargs or {}
118119
self.signature = signature
120+
self.parent = parent
119121

120122
def __hash__(self) -> int:
121123
return hash(self._id)

taskiq_dependencies/graph.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
from collections import defaultdict, deque
33
from graphlib import TopologicalSorter
4-
from typing import Any, Callable, Dict, List, Optional, get_type_hints
4+
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints
55

66
from taskiq_dependencies.ctx import AsyncResolveContext, SyncResolveContext
77
from taskiq_dependencies.dependency import Dependency
@@ -103,18 +103,55 @@ def _build_graph(self) -> None: # noqa: C901, WPS210
103103
if dep.dependency is None:
104104
continue
105105
# Get signature and type hints.
106-
sign = inspect.signature(dep.dependency)
107-
if inspect.isclass(dep.dependency):
106+
origin = getattr(dep.dependency, "__origin__", None)
107+
if origin is None:
108+
origin = dep.dependency
109+
110+
# If we found the typevar.
111+
# It means, that somebody depend on generic type.
112+
if isinstance(origin, TypeVar):
113+
if dep.parent is None:
114+
raise ValueError(f"Cannot resolve generic {dep.dependency}")
115+
parent_cls = dep.parent.dependency
116+
parent_cls_origin = getattr(parent_cls, "__origin__", None)
117+
# If we cannot find origin, than means, that we cannot resolve
118+
# generic parameters. So exiting.
119+
if parent_cls_origin is None:
120+
raise ValueError(
121+
f"Unknown generic argument {origin}. "
122+
+ f"Please provide a type in param `{dep.parent.param_name}`"
123+
+ f" of `{dep.parent.dependency}`",
124+
)
125+
# We zip together names of parameters and the subsctituted values
126+
# In parameters we would see TypeVars in args
127+
# we would find actual classes.
128+
generics = zip(
129+
parent_cls_origin.__parameters__,
130+
parent_cls.__args__, # type: ignore
131+
)
132+
for tvar, type_param in generics:
133+
# If we found the typevar we're currently try to resolve,
134+
# we need to find origin of the substituted class.
135+
if tvar == origin:
136+
dep.dependency = type_param
137+
origin = getattr(type_param, "__origin__", None)
138+
if origin is None:
139+
origin = type_param
140+
141+
if inspect.isclass(origin):
108142
# If this is a class, we need to get signature of
109143
# an __init__ method.
110-
hints = get_type_hints(dep.dependency.__init__) # noqa: WPS609
144+
hints = get_type_hints(origin.__init__) # noqa: WPS609
145+
sign = inspect.signature(origin.__init__) # noqa: WPS609
111146
elif inspect.isfunction(dep.dependency):
112147
# If this is function or an instance of a class, we get it's type hints.
113148
hints = get_type_hints(dep.dependency)
149+
sign = inspect.signature(origin) # type: ignore
114150
else:
115151
hints = get_type_hints(
116152
dep.dependency.__call__, # type: ignore # noqa: WPS609
117153
)
154+
sign = inspect.signature(origin) # type: ignore
118155

119156
# Now we need to iterate over parameters, to
120157
# find all parameters, that have TaskiqDepends as it's
@@ -172,6 +209,7 @@ def _build_graph(self) -> None: # noqa: C901, WPS210
172209
use_cache=default_value.use_cache,
173210
kwargs=default_value.kwargs,
174211
signature=param,
212+
parent=dep,
175213
)
176214
# Also we set the parameter name,
177215
# it will help us in future when

tests/test_graph.py

Lines changed: 131 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import uuid
3-
from typing import Any, AsyncGenerator, Generator
3+
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
44

55
import pytest
66

@@ -313,7 +313,6 @@ def target(class_val: str = Depends(TeClass("tval"))) -> None:
313313

314314

315315
def test_exception_generators() -> None:
316-
317316
errors_found = 0
318317

319318
def my_generator() -> Generator[int, None, None]:
@@ -335,7 +334,6 @@ def target(_: int = Depends(my_generator)) -> None:
335334

336335
@pytest.mark.anyio
337336
async def test_async_exception_generators() -> None:
338-
339337
errors_found = 0
340338

341339
async def my_generator() -> AsyncGenerator[int, None]:
@@ -357,7 +355,6 @@ def target(_: int = Depends(my_generator)) -> None:
357355

358356
@pytest.mark.anyio
359357
async def test_async_exception_generators_multiple() -> None:
360-
361358
errors_found = 0
362359

363360
async def my_generator() -> AsyncGenerator[int, None]:
@@ -383,7 +380,6 @@ def target(
383380

384381
@pytest.mark.anyio
385382
async def test_async_exception_in_teardown() -> None:
386-
387383
errors_found = 0
388384

389385
async def my_generator() -> AsyncGenerator[int, None]:
@@ -404,7 +400,6 @@ def target(_: int = Depends(my_generator)) -> None:
404400

405401
@pytest.mark.anyio
406402
async def test_async_propagation_disabled() -> None:
407-
408403
errors_found = 0
409404

410405
async def my_generator() -> AsyncGenerator[int, None]:
@@ -428,7 +423,6 @@ def target(_: int = Depends(my_generator)) -> None:
428423

429424

430425
def test_sync_propagation_disabled() -> None:
431-
432426
errors_found = 0
433427

434428
def my_generator() -> Generator[int, None, None]:
@@ -447,3 +441,133 @@ def target(_: int = Depends(my_generator)) -> None:
447441
target(**(g.resolve_kwargs()))
448442

449443
assert errors_found == 0
444+
445+
446+
def test_generic_classes() -> None:
447+
errors_found = 0
448+
449+
_T = TypeVar("_T")
450+
451+
class MyClass:
452+
pass
453+
454+
class MainClass(Generic[_T]):
455+
def __init__(self, val: _T = Depends()) -> None:
456+
self.val = val
457+
458+
def test_func(a: MainClass[MyClass] = Depends()) -> MyClass:
459+
return a.val
460+
461+
with DependencyGraph(target=test_func).sync_ctx(exception_propagation=False) as g:
462+
value = test_func(**(g.resolve_kwargs()))
463+
464+
assert errors_found == 0
465+
assert isinstance(value, MyClass)
466+
467+
468+
def test_generic_multiple() -> None:
469+
errors_found = 0
470+
471+
_T = TypeVar("_T")
472+
_V = TypeVar("_V")
473+
474+
class MyClass1:
475+
pass
476+
477+
class MyClass2:
478+
pass
479+
480+
class MainClass(Generic[_T, _V]):
481+
def __init__(self, t_val: _T = Depends(), v_val: _V = Depends()) -> None:
482+
self.t_val = t_val
483+
self.v_val = v_val
484+
485+
def test_func(
486+
a: MainClass[MyClass1, MyClass2] = Depends(),
487+
) -> MainClass[MyClass1, MyClass2]:
488+
return a
489+
490+
with DependencyGraph(target=test_func).sync_ctx(exception_propagation=False) as g:
491+
result = test_func(**(g.resolve_kwargs()))
492+
493+
assert errors_found == 0
494+
assert isinstance(result.t_val, MyClass1)
495+
assert isinstance(result.v_val, MyClass2)
496+
497+
498+
def test_generic_unordered() -> None:
499+
errors_found = 0
500+
501+
_T = TypeVar("_T")
502+
_V = TypeVar("_V")
503+
504+
class MyClass1:
505+
pass
506+
507+
class MyClass2:
508+
pass
509+
510+
class MainClass(Generic[_T, _V]):
511+
def __init__(self, v_val: _V = Depends(), t_val: _T = Depends()) -> None:
512+
self.t_val = t_val
513+
self.v_val = v_val
514+
515+
def test_func(
516+
a: MainClass[MyClass1, MyClass2] = Depends(),
517+
) -> MainClass[MyClass1, MyClass2]:
518+
return a
519+
520+
with DependencyGraph(target=test_func).sync_ctx(exception_propagation=False) as g:
521+
result = test_func(**(g.resolve_kwargs()))
522+
523+
assert errors_found == 0
524+
assert isinstance(result.t_val, MyClass1)
525+
assert isinstance(result.v_val, MyClass2)
526+
527+
528+
def test_generic_classes_nesting() -> None:
529+
errors_found = 0
530+
531+
_T = TypeVar("_T")
532+
_V = TypeVar("_V")
533+
534+
class DummyClass:
535+
pass
536+
537+
class DependantClass(Generic[_V]):
538+
def __init__(self, var: _V = Depends()) -> None:
539+
self.var = var
540+
541+
class MainClass(Generic[_T]):
542+
def __init__(self, var: _T = Depends()) -> None:
543+
self.var = var
544+
545+
def test_func(a: MainClass[DependantClass[DummyClass]] = Depends()) -> DummyClass:
546+
return a.var.var
547+
548+
with DependencyGraph(target=test_func).sync_ctx(exception_propagation=False) as g:
549+
value = test_func(**(g.resolve_kwargs()))
550+
551+
assert errors_found == 0
552+
assert isinstance(value, DummyClass)
553+
554+
555+
def test_generic_class_based_dependencies() -> None:
556+
"""Tests that if ParamInfo is used on the target, no error is raised."""
557+
558+
_T = TypeVar("_T")
559+
560+
class GenericClass(Generic[_T]):
561+
def __init__(self, class_val: _T = Depends()):
562+
self.return_val = class_val
563+
564+
def func_dep() -> GenericClass[int]:
565+
return GenericClass(123)
566+
567+
def target(my_dep: GenericClass[int] = Depends(func_dep)) -> int:
568+
return my_dep.return_val
569+
570+
with DependencyGraph(target=target).sync_ctx() as g:
571+
result = target(**g.resolve_kwargs())
572+
573+
assert result == 123

0 commit comments

Comments
 (0)