Skip to content

Commit 1e2de51

Browse files
committed
Add context manager support to Resource provider
1 parent 4b3476c commit 1e2de51

File tree

4 files changed

+208
-185
lines changed

4 files changed

+208
-185
lines changed

src/dependency_injector/providers.pyx

Lines changed: 84 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@ import re
1515
import sys
1616
import threading
1717
import warnings
18+
from asyncio import ensure_future
1819
from configparser import ConfigParser as IniConfigParser
20+
from contextlib import asynccontextmanager, contextmanager
1921
from contextvars import ContextVar
22+
from inspect import isasyncgenfunction, isgeneratorfunction
2023

2124
try:
2225
from inspect import _is_coroutine_mark as _is_coroutine_marker
@@ -3598,6 +3601,17 @@ cdef class Dict(Provider):
35983601
return __provide_keyword_args(kwargs, self._kwargs, self._kwargs_len, self._async_mode)
35993602

36003603

3604+
@cython.no_gc
3605+
cdef class NullAwaitable:
3606+
def __next__(self):
3607+
raise StopIteration from None
3608+
3609+
def __await__(self):
3610+
return self
3611+
3612+
3613+
cdef NullAwaitable NULL_AWAITABLE = NullAwaitable()
3614+
36013615

36023616
cdef class Resource(Provider):
36033617
"""Resource provider provides a component with initialization and shutdown."""
@@ -3653,6 +3667,12 @@ cdef class Resource(Provider):
36533667
def set_provides(self, provides):
36543668
"""Set provider provides."""
36553669
provides = _resolve_string_import(provides)
3670+
3671+
if isasyncgenfunction(provides):
3672+
provides = asynccontextmanager(provides)
3673+
elif isgeneratorfunction(provides):
3674+
provides = contextmanager(provides)
3675+
36563676
self._provides = provides
36573677
return self
36583678

@@ -3753,28 +3773,21 @@ cdef class Resource(Provider):
37533773
"""Shutdown resource."""
37543774
if not self._initialized:
37553775
if self._async_mode == ASYNC_MODE_ENABLED:
3756-
result = asyncio.Future()
3757-
result.set_result(None)
3758-
return result
3776+
return NULL_AWAITABLE
37593777
return
37603778

37613779
if self._shutdowner:
3762-
try:
3763-
shutdown = self._shutdowner(self._resource)
3764-
except StopIteration:
3765-
pass
3766-
else:
3767-
if inspect.isawaitable(shutdown):
3768-
return self._create_shutdown_future(shutdown)
3780+
future = self._shutdowner(None, None, None)
3781+
3782+
if __is_future_or_coroutine(future):
3783+
return ensure_future(self._shutdown_async(future))
37693784

37703785
self._resource = None
37713786
self._initialized = False
37723787
self._shutdowner = None
37733788

37743789
if self._async_mode == ASYNC_MODE_ENABLED:
3775-
result = asyncio.Future()
3776-
result.set_result(None)
3777-
return result
3790+
return NULL_AWAITABLE
37783791

37793792
@property
37803793
def related(self):
@@ -3784,165 +3797,75 @@ cdef class Resource(Provider):
37843797
yield from filter(is_provider, self.kwargs.values())
37853798
yield from super().related
37863799

3800+
async def _shutdown_async(self, future) -> None:
3801+
try:
3802+
await future
3803+
finally:
3804+
self._resource = None
3805+
self._initialized = False
3806+
self._shutdowner = None
3807+
3808+
async def _handle_async_cm(self, obj) -> None:
3809+
try:
3810+
self._resource = resource = await obj.__aenter__()
3811+
self._shutdowner = obj.__aexit__
3812+
return resource
3813+
except:
3814+
self._initialized = False
3815+
raise
3816+
3817+
async def _provide_async(self, future) -> None:
3818+
try:
3819+
obj = await future
3820+
3821+
if hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
3822+
self._resource = await obj.__aenter__()
3823+
self._shutdowner = obj.__aexit__
3824+
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
3825+
self._resource = obj.__enter__()
3826+
self._shutdowner = obj.__exit__
3827+
else:
3828+
self._resource = obj
3829+
self._shutdowner = None
3830+
3831+
return self._resource
3832+
except:
3833+
self._initialized = False
3834+
raise
3835+
37873836
cpdef object _provide(self, tuple args, dict kwargs):
37883837
if self._initialized:
37893838
return self._resource
37903839

3791-
if self._is_resource_subclass(self._provides):
3792-
initializer = self._provides()
3793-
self._resource = __call(
3794-
initializer.init,
3795-
args,
3796-
self._args,
3797-
self._args_len,
3798-
kwargs,
3799-
self._kwargs,
3800-
self._kwargs_len,
3801-
self._async_mode,
3802-
)
3803-
self._shutdowner = initializer.shutdown
3804-
elif self._is_async_resource_subclass(self._provides):
3805-
initializer = self._provides()
3806-
async_init = __call(
3807-
initializer.init,
3808-
args,
3809-
self._args,
3810-
self._args_len,
3811-
kwargs,
3812-
self._kwargs,
3813-
self._kwargs_len,
3814-
self._async_mode,
3815-
)
3816-
self._initialized = True
3817-
return self._create_init_future(async_init, initializer.shutdown)
3818-
elif inspect.isgeneratorfunction(self._provides):
3819-
initializer = __call(
3820-
self._provides,
3821-
args,
3822-
self._args,
3823-
self._args_len,
3824-
kwargs,
3825-
self._kwargs,
3826-
self._kwargs_len,
3827-
self._async_mode,
3828-
)
3829-
self._resource = next(initializer)
3830-
self._shutdowner = initializer.send
3831-
elif iscoroutinefunction(self._provides):
3832-
initializer = __call(
3833-
self._provides,
3834-
args,
3835-
self._args,
3836-
self._args_len,
3837-
kwargs,
3838-
self._kwargs,
3839-
self._kwargs_len,
3840-
self._async_mode,
3841-
)
3840+
obj = __call(
3841+
self._provides,
3842+
args,
3843+
self._args,
3844+
self._args_len,
3845+
kwargs,
3846+
self._kwargs,
3847+
self._kwargs_len,
3848+
self._async_mode,
3849+
)
3850+
3851+
if __is_future_or_coroutine(obj):
38423852
self._initialized = True
3843-
return self._create_init_future(initializer)
3844-
elif isasyncgenfunction(self._provides):
3845-
initializer = __call(
3846-
self._provides,
3847-
args,
3848-
self._args,
3849-
self._args_len,
3850-
kwargs,
3851-
self._kwargs,
3852-
self._kwargs_len,
3853-
self._async_mode,
3854-
)
3853+
self._resource = resource = ensure_future(self._provide_async(obj))
3854+
return resource
3855+
elif hasattr(obj, '__enter__') and hasattr(obj, '__exit__'):
3856+
self._resource = obj.__enter__()
3857+
self._shutdowner = obj.__exit__
3858+
elif hasattr(obj, '__aenter__') and hasattr(obj, '__aexit__'):
38553859
self._initialized = True
3856-
return self._create_async_gen_init_future(initializer)
3857-
elif callable(self._provides):
3858-
self._resource = __call(
3859-
self._provides,
3860-
args,
3861-
self._args,
3862-
self._args_len,
3863-
kwargs,
3864-
self._kwargs,
3865-
self._kwargs_len,
3866-
self._async_mode,
3867-
)
3860+
self._resource = resource = ensure_future(self._handle_async_cm(obj))
3861+
return resource
38683862
else:
3869-
raise Error("Unknown type of resource initializer")
3863+
self._resource = obj
3864+
self._shutdowner = None
38703865

38713866
self._initialized = True
38723867
return self._resource
38733868

3874-
def _create_init_future(self, future, shutdowner=None):
3875-
callback = self._async_init_callback
3876-
if shutdowner:
3877-
callback = functools.partial(callback, shutdowner=shutdowner)
3878-
3879-
future = asyncio.ensure_future(future)
3880-
future.add_done_callback(callback)
3881-
self._resource = future
3882-
3883-
return future
3884-
3885-
def _create_async_gen_init_future(self, initializer):
3886-
if inspect.isasyncgen(initializer):
3887-
return self._create_init_future(initializer.__anext__(), initializer.asend)
3888-
3889-
future = asyncio.Future()
3890-
3891-
create_initializer = asyncio.ensure_future(initializer)
3892-
create_initializer.add_done_callback(functools.partial(self._async_create_gen_callback, future))
3893-
self._resource = future
3894-
3895-
return future
3896-
3897-
def _async_init_callback(self, initializer, shutdowner=None):
3898-
try:
3899-
resource = initializer.result()
3900-
except Exception:
3901-
self._initialized = False
3902-
else:
3903-
self._resource = resource
3904-
self._shutdowner = shutdowner
3905-
3906-
def _async_create_gen_callback(self, future, initializer_future):
3907-
initializer = initializer_future.result()
3908-
init_future = self._create_init_future(initializer.__anext__(), initializer.asend)
3909-
init_future.add_done_callback(functools.partial(self._async_trigger_result, future))
3910-
3911-
def _async_trigger_result(self, future, future_result):
3912-
future.set_result(future_result.result())
3913-
3914-
def _create_shutdown_future(self, shutdown_future):
3915-
future = asyncio.Future()
3916-
shutdown_future = asyncio.ensure_future(shutdown_future)
3917-
shutdown_future.add_done_callback(functools.partial(self._async_shutdown_callback, future))
3918-
return future
3919-
3920-
def _async_shutdown_callback(self, future_result, shutdowner):
3921-
try:
3922-
shutdowner.result()
3923-
except StopAsyncIteration:
3924-
pass
3925-
3926-
self._resource = None
3927-
self._initialized = False
3928-
self._shutdowner = None
3929-
3930-
future_result.set_result(None)
3931-
3932-
@staticmethod
3933-
def _is_resource_subclass(instance):
3934-
if not isinstance(instance, type):
3935-
return
3936-
from . import resources
3937-
return issubclass(instance, resources.Resource)
3938-
3939-
@staticmethod
3940-
def _is_async_resource_subclass(instance):
3941-
if not isinstance(instance, type):
3942-
return
3943-
from . import resources
3944-
return issubclass(instance, resources.AsyncResource)
3945-
39463869

39473870
cdef class Container(Provider):
39483871
"""Container provider provides an instance of declarative container.
@@ -4993,14 +4916,6 @@ def iscoroutinefunction(obj):
49934916
return False
49944917

49954918

4996-
def isasyncgenfunction(obj):
4997-
"""Check if object is an asynchronous generator function."""
4998-
try:
4999-
return inspect.isasyncgenfunction(obj)
5000-
except AttributeError:
5001-
return False
5002-
5003-
50044919
def _resolve_string_import(provides):
50054920
if provides is None:
50064921
return provides

src/dependency_injector/resources.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,54 @@
11
"""Resources module."""
22

3-
import abc
4-
from typing import TypeVar, Generic, Optional
5-
3+
from abc import ABCMeta, abstractmethod
4+
from typing import Any, ClassVar, Generic, Optional, Tuple, TypeVar
65

76
T = TypeVar("T")
87

98

10-
class Resource(Generic[T], metaclass=abc.ABCMeta):
9+
class Resource(Generic[T], metaclass=ABCMeta):
10+
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
11+
12+
obj: Optional[T]
13+
14+
def __init__(self, *args: Any, **kwargs: Any) -> None:
15+
self.args = args
16+
self.kwargs = kwargs
17+
self.obj = None
1118

12-
@abc.abstractmethod
13-
def init(self, *args, **kwargs) -> Optional[T]: ...
19+
@abstractmethod
20+
def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
1421

1522
def shutdown(self, resource: Optional[T]) -> None: ...
1623

24+
def __enter__(self) -> Optional[T]:
25+
self.obj = obj = self.init(*self.args, **self.kwargs)
26+
return obj
27+
28+
def __exit__(self, *exc_info: Any) -> None:
29+
self.shutdown(self.obj)
30+
self.obj = None
31+
1732

18-
class AsyncResource(Generic[T], metaclass=abc.ABCMeta):
33+
class AsyncResource(Generic[T], metaclass=ABCMeta):
34+
__slots__: ClassVar[Tuple[str, ...]] = ("args", "kwargs", "obj")
1935

20-
@abc.abstractmethod
21-
async def init(self, *args, **kwargs) -> Optional[T]: ...
36+
obj: Optional[T]
37+
38+
def __init__(self, *args: Any, **kwargs: Any) -> None:
39+
self.args = args
40+
self.kwargs = kwargs
41+
self.obj = None
42+
43+
@abstractmethod
44+
async def init(self, *args: Any, **kwargs: Any) -> Optional[T]: ...
2245

2346
async def shutdown(self, resource: Optional[T]) -> None: ...
47+
48+
async def __aenter__(self) -> Optional[T]:
49+
self.obj = obj = await self.init(*self.args, **self.kwargs)
50+
return obj
51+
52+
async def __aexit__(self, *exc_info: Any) -> None:
53+
await self.shutdown(self.obj)
54+
self.obj = None

0 commit comments

Comments
 (0)