Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions cq/_core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from collections.abc import Awaitable, Callable, Iterator
from dataclasses import dataclass, field
from functools import partial
from inspect import Parameter, getmro, isclass
from inspect import Parameter, isclass
from inspect import signature as inspect_signature
from typing import TYPE_CHECKING, Any, Protocol, Self, overload, runtime_checkable

import injection
from type_analyzer import MatchingTypesConfig, iter_matching_types, matching_types

type HandlerType[**P, T] = type[Handler[P, T]]
type HandlerFactory[**P, T] = Callable[..., Awaitable[Handler[P, T]]]
Expand Down Expand Up @@ -49,12 +50,14 @@ def handlers_from(
self,
input_type: type[I],
) -> Iterator[Callable[[I], Awaitable[O]]]:
for it in getmro(input_type):
for factory in self.__factories.get(it, ()):
for key_type in _iter_key_types(input_type):
for factory in self.__factories.get(key_type, ()):
yield _make_handle_function(factory)

def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
self.__factories[input_type].append(factory)
for key_type in _build_key_types(input_type):
self.__factories[key_type].append(factory)

return self


Expand All @@ -69,18 +72,20 @@ def handlers_from(
self,
input_type: type[I],
) -> Iterator[Callable[[I], Awaitable[O]]]:
for it in getmro(input_type):
factory = self.__factories.get(it, None)
for key_type in _iter_key_types(input_type):
factory = self.__factories.get(key_type, None)
if factory is not None:
yield _make_handle_function(factory)

def subscribe(self, input_type: type[I], factory: HandlerFactory[[I], O]) -> Self:
if input_type in self.__factories:
raise RuntimeError(
f"A handler is already registered for the input type: `{input_type}`."
)
for key_type in _build_key_types(input_type):
if key_type in self.__factories:
raise RuntimeError(
f"A handler is already registered for the input type: `{key_type}`."
)

self.__factories[key_type] = factory

self.__factories[input_type] = factory
return self


Expand Down Expand Up @@ -152,6 +157,20 @@ def __decorator(
return wrapped


def _build_key_types(input_type: Any) -> tuple[Any, ...]:
config = MatchingTypesConfig(ignore_none=True)
return matching_types(input_type, config)


def _iter_key_types(input_type: Any) -> Iterator[Any]:
config = MatchingTypesConfig(
with_bases=True,
with_origin=True,
with_type_alias_value=True,
)
return iter_matching_types(input_type, config)


def _resolve_input_type[I, O](handler_type: HandlerType[[I], O]) -> type[I]:
fake_method = handler_type.handle.__get__(NotImplemented, handler_type)
signature = inspect_signature(fake_method, eval_str=True)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ classifiers = [
dependencies = [
"anyio",
"python-injection",
"type-analyzer",
]

[project.urls]
Expand Down
Loading
Loading