Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions injection/_core/common/asynchronous.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from abc import abstractmethod
from collections.abc import Awaitable, Callable, Generator
from dataclasses import dataclass
from typing import Any, NoReturn, Protocol, override, runtime_checkable
from typing import Any, NoReturn, Protocol, runtime_checkable


@dataclass(repr=False, eq=False, frozen=True, slots=True)
class SimpleAwaitable[T](Awaitable[T]):
callable: Callable[..., Awaitable[T]]

@override
def __await__(self) -> Generator[Any, Any, T]:
return self.callable().__await__()

Expand All @@ -30,11 +29,9 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
class AsyncCaller[**P, T](Caller[P, T]):
callable: Callable[P, Awaitable[T]]

@override
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return await self.callable(*args, **kwargs)

@override
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
raise RuntimeError(
"Synchronous call isn't supported for an asynchronous Callable."
Expand All @@ -45,10 +42,8 @@ def call(self, /, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
class SyncCaller[**P, T](Caller[P, T]):
callable: Callable[P, T]

@override
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.callable(*args, **kwargs)

@override
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.callable(*args, **kwargs)
3 changes: 1 addition & 2 deletions injection/_core/common/invertible.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Protocol, override, runtime_checkable
from typing import Protocol, runtime_checkable


@runtime_checkable
Expand All @@ -15,6 +15,5 @@ def __invert__(self) -> T:
class SimpleInvertible[T](Invertible[T]):
callable: Callable[..., T]

@override
def __invert__(self) -> T:
return self.callable()
5 changes: 0 additions & 5 deletions injection/_core/common/lazy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Callable, Iterator, Mapping
from types import MappingProxyType
from typing import override

from injection._core.common.invertible import Invertible

Expand All @@ -14,7 +13,6 @@ class Lazy[T](Invertible[T]):
def __init__(self, factory: Callable[..., T]) -> None:
self.__setup_cache(factory)

@override
def __invert__(self) -> T:
return next(self.__iterator)

Expand Down Expand Up @@ -44,15 +42,12 @@ class LazyMapping[K, V](Mapping[K, V]):
def __init__(self, iterator: Iterator[tuple[K, V]]) -> None:
self.__lazy = Lazy(lambda: MappingProxyType(dict(iterator)))

@override
def __getitem__(self, key: K, /) -> V:
return (~self.__lazy)[key]

@override
def __iter__(self) -> Iterator[K]:
yield from ~self.__lazy

@override
def __len__(self) -> int:
return len(~self.__lazy)

Expand Down
10 changes: 1 addition & 9 deletions injection/_core/injectables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import MutableMapping
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, ClassVar, NoReturn, Protocol, override, runtime_checkable
from typing import Any, ClassVar, NoReturn, Protocol, runtime_checkable

from injection._core.common.asynchronous import Caller
from injection._core.common.threading import synchronized
Expand Down Expand Up @@ -37,11 +37,9 @@ class BaseInjectable[T](Injectable[T], ABC):
class SimpleInjectable[T](BaseInjectable[T]):
__slots__ = ()

@override
async def aget_instance(self) -> T:
return await self.factory.acall()

@override
def get_instance(self) -> T:
return self.factory.call()

Expand All @@ -56,15 +54,12 @@ def cache(self) -> MutableMapping[str, Any]:
return self.__dict__

@property
@override
def is_locked(self) -> bool:
return self.__key in self.cache

@override
def unlock(self) -> None:
self.cache.clear()

@override
async def aget_instance(self) -> T:
with suppress(KeyError):
return self.__check_instance()
Expand All @@ -75,7 +70,6 @@ async def aget_instance(self) -> T:

return instance

@override
def get_instance(self) -> T:
with suppress(KeyError):
return self.__check_instance()
Expand All @@ -97,10 +91,8 @@ def __set_instance(self, value: T) -> None:
class ShouldBeInjectable[T](Injectable[T]):
cls: type[T]

@override
async def aget_instance(self) -> T:
return self.get_instance()

@override
def get_instance(self) -> NoReturn:
raise InjectionError(f"`{self.cls}` should be an injectable.")
50 changes: 8 additions & 42 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import (
Expand All @@ -17,7 +16,8 @@
from dataclasses import dataclass, field
from enum import StrEnum
from functools import partialmethod, singledispatchmethod, update_wrapper
from inspect import Signature, isclass, iscoroutinefunction
from inspect import Signature, isclass, iscoroutinefunction, markcoroutinefunction
from inspect import signature as inspect_signature
from logging import Logger, getLogger
from queue import Empty, Queue
from types import MethodType
Expand All @@ -29,9 +29,7 @@
NamedTuple,
Protocol,
Self,
TypeGuard,
overload,
override,
runtime_checkable,
)
from uuid import uuid4
Expand Down Expand Up @@ -76,7 +74,6 @@ class LocatorDependenciesUpdated[T](LocatorEvent):
classes: Collection[InputType[T]]
mode: Mode

@override
def __str__(self) -> str:
length = len(self.classes)
formatted_types = ", ".join(f"`{cls}`" for cls in self.classes)
Expand All @@ -95,7 +92,6 @@ class ModuleEvent(Event, ABC):
class ModuleEventProxy(ModuleEvent):
event: Event

@override
def __str__(self) -> str:
return f"`{self.module}` has propagated an event: {self.origin}"

Expand All @@ -116,7 +112,6 @@ class ModuleAdded(ModuleEvent):
module_added: Module
priority: Priority

@override
def __str__(self) -> str:
return f"`{self.module}` now uses `{self.module_added}`."

Expand All @@ -125,7 +120,6 @@ def __str__(self) -> str:
class ModuleRemoved(ModuleEvent):
module_removed: Module

@override
def __str__(self) -> str:
return f"`{self.module}` no longer uses `{self.module_removed}`."

Expand All @@ -135,7 +129,6 @@ class ModulePriorityUpdated(ModuleEvent):
module_updated: Module
priority: Priority

@override
def __str__(self) -> str:
return (
f"In `{self.module}`, the priority `{self.priority}` "
Expand Down Expand Up @@ -242,7 +235,6 @@ class Locator(Broker):

static_hooks: ClassVar[LocatorHooks[Any]] = LocatorHooks.default()

@override
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
for input_class in self.__standardize_inputs((cls,)):
try:
Expand All @@ -254,15 +246,13 @@ def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:

raise NoInjectable(cls)

@override
def __contains__(self, cls: InputType[Any], /) -> bool:
return any(
input_class in self.__records
for input_class in self.__standardize_inputs((cls,))
)

@property
@override
def is_locked(self) -> bool:
return any(injectable.is_locked for injectable in self.__injectables)

Expand All @@ -284,15 +274,13 @@ def update[T](self, updater: Updater[T]) -> Self:

return self

@override
@synchronized()
def unlock(self) -> Self:
for injectable in self.__injectables:
injectable.unlock()

return self

@override
async def all_ready(self) -> None:
for injectable in self.__injectables:
await injectable.aget_instance()
Expand Down Expand Up @@ -387,20 +375,17 @@ class Module(Broker, EventListener):
def __post_init__(self) -> None:
self.__locator.add_listener(self)

@override
def __getitem__[T](self, cls: InputType[T], /) -> Injectable[T]:
for broker in self.__brokers:
with suppress(KeyError):
return broker[cls]

raise NoInjectable(cls)

@override
def __contains__(self, cls: InputType[Any], /) -> bool:
return any(cls in broker for broker in self.__brokers)

@property
@override
def is_locked(self) -> bool:
return any(broker.is_locked for broker in self.__brokers)

Expand Down Expand Up @@ -695,15 +680,13 @@ def change_priority(self, module: Module, priority: Priority | PriorityStr) -> S

return self

@override
@synchronized()
def unlock(self) -> Self:
for broker in self.__brokers:
broker.unlock()

return self

@override
async def all_ready(self) -> None:
for broker in self.__brokers:
await broker.all_ready()
Expand All @@ -720,7 +703,6 @@ def remove_listener(self, listener: EventListener) -> Self:
self.__channel.remove_listener(listener)
return self

@override
def on_event(self, event: Event, /) -> ContextManager[None] | None:
self_event = ModuleEventProxy(self, event)
return self.dispatch(self_event)
Expand Down Expand Up @@ -890,7 +872,7 @@ def signature(self) -> Signature:
return self.__signature

with synchronized():
signature = inspect.signature(self.wrapped, eval_str=True)
signature = inspect_signature(self.wrapped, eval_str=True)
self.__signature = signature

return signature
Expand All @@ -915,13 +897,11 @@ def bind(
additional_arguments = self.__dependencies.get_arguments()
return self.__bind(args, kwargs, additional_arguments)

@override
async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
self.__setup()
arguments = await self.abind(args, kwargs)
return self.wrapped(*arguments.args, **arguments.kwargs)

@override
def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
self.__setup()
arguments = self.bind(args, kwargs)
Expand Down Expand Up @@ -957,7 +937,6 @@ def decorator(wp: Callable[_P, _T]) -> Callable[_P, _T]:
return decorator(wrapped) if wrapped else decorator

@singledispatchmethod
@override
def on_event(self, event: Event, /) -> ContextManager[None] | None: # type: ignore[override]
return None

Expand Down Expand Up @@ -1014,11 +993,9 @@ def __init__(self, metadata: InjectMetadata[P, T]) -> None:
update_wrapper(self, metadata.wrapped)
self.__inject_metadata__ = metadata

@override
def __repr__(self) -> str: # pragma: no cover
return repr(self.__inject_metadata__.wrapped)

@override
def __str__(self) -> str: # pragma: no cover
return str(self.__inject_metadata__.wrapped)

Expand All @@ -1043,34 +1020,23 @@ def __set_name__(self, owner: type, name: str) -> None:
class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
__slots__ = ()

@override
def __init__(self, metadata: InjectMetadata[P, Awaitable[T]]) -> None:
super().__init__(metadata)
markcoroutinefunction(self)

async def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return await (await self.__inject_metadata__.acall(*args, **kwargs))


class SyncInjectedFunction[**P, T](InjectedFunction[P, T]):
__slots__ = ()

@override
def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.__inject_metadata__.call(*args, **kwargs)


def _is_coroutine_function[**P, T](
function: Callable[P, T] | Callable[P, Awaitable[T]],
) -> TypeGuard[Callable[P, Awaitable[T]]]:
if iscoroutinefunction(function):
return True

elif isclass(function):
return False

call = getattr(function, "__call__", None)
return iscoroutinefunction(call)


def _get_caller[**P, T](function: Callable[P, T]) -> Caller[P, T]:
if _is_coroutine_function(function):
if iscoroutinefunction(function):
return AsyncCaller(function)

elif isinstance(function, InjectedFunction):
Expand Down
Loading