diff --git a/pyproject.toml b/pyproject.toml index fae02e8..6d39c8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,4 +57,5 @@ exclude_lines = [ "def __repr__", "def __str__", "def __repr_args__", + "except ImportError:", # Don't complain about import fallback ] diff --git a/src/graia/broadcast/interfaces/dispatcher.py b/src/graia/broadcast/interfaces/dispatcher.py index 1f1a0f9..48a89eb 100644 --- a/src/graia/broadcast/interfaces/dispatcher.py +++ b/src/graia/broadcast/interfaces/dispatcher.py @@ -84,7 +84,7 @@ def event(self) -> T_Event: @property def is_optional(self) -> bool: anno = self.annotation - return get_origin(anno) is Union and None in get_args(anno) + return get_origin(anno) is Union and type(None) in get_args(anno) @property def is_annotated(self) -> bool: diff --git a/src/test/dispatch.py b/src/test/dispatch.py index 8e92ae9..ac498f6 100644 --- a/src/test/dispatch.py +++ b/src/test/dispatch.py @@ -6,18 +6,30 @@ from graia.broadcast.entities.dispatcher import BaseDispatcher from graia.broadcast.entities.event import Dispatchable from graia.broadcast.entities.signatures import Force +from graia.broadcast.exceptions import ExecutionStop, RequirementCrashed from graia.broadcast.interfaces.dispatcher import DispatcherInterface -class TestDispatcher(BaseDispatcher): - @staticmethod - async def catch(interface: DispatcherInterface): +class RandomDispatcher(BaseDispatcher): + def __init__(self, t: bool = False) -> None: + self.second_exec = t + + async def catch(self, interface: DispatcherInterface): if interface.name == "p": return "P_dispatcher" elif interface.name == "f": + if self.second_exec: + return + self.second_exec = True return Force(2) +class CrashDispatcher(BaseDispatcher): + @staticmethod + async def catch(i: DispatcherInterface): + i.crash() + + class TestEvent(Dispatchable): class Dispatcher(BaseDispatcher): @staticmethod @@ -36,8 +48,21 @@ async def test_lookup_directly(): dii = DispatcherInterface(bcc, []) - assert await dii.lookup_by_directly(TestDispatcher, "p", None, None) == "P_dispatcher" - assert await dii.lookup_by_directly(TestDispatcher, "f", None, None) == 2 + assert await dii.lookup_by_directly(RandomDispatcher(), "p", None, None) == "P_dispatcher" + assert await dii.lookup_by_directly(RandomDispatcher(), "f", None, None) == 2 + + +@pytest.mark.asyncio +async def test_crash(): + bcc = Broadcast( + loop=asyncio.get_running_loop(), + ) + + dii = DispatcherInterface(bcc, []) + with pytest.raises(RequirementCrashed): + await dii.lookup_by_directly(CrashDispatcher, "u", None, None) + with pytest.raises(ExecutionStop): + dii.stop() @pytest.mark.asyncio @@ -48,8 +73,8 @@ async def test_insert(): dii = DispatcherInterface(bcc, []) - t_a = TestDispatcher() - t_b = TestDispatcher() + t_a = RandomDispatcher() + t_b = RandomDispatcher() dii.inject_execution_raw(t_a, t_b) assert dii.dispatchers == [t_b, t_a] @@ -83,7 +108,7 @@ async def test_dispatcher_catch(): executed = [] - @bcc.receiver(TestEvent, dispatchers=[TestDispatcher]) + @bcc.receiver(TestEvent, dispatchers=[RandomDispatcher(), RandomDispatcher()]) async def _1(f, b: Broadcast, i: DispatcherInterface): assert f == 2 assert b is bcc @@ -91,8 +116,9 @@ async def _1(f, b: Broadcast, i: DispatcherInterface): executed.append(1) await bcc.postEvent(TestEvent()) + await bcc.postEvent(TestEvent()) - assert len(executed) == 1 + assert len(executed) == 2 @pytest.mark.asyncio @@ -104,7 +130,7 @@ async def test_dispatcher_priority(): executed = [] - @bcc.receiver(TestEvent, dispatchers=[TestDispatcher]) + @bcc.receiver(TestEvent, dispatchers=[RandomDispatcher]) async def _2(ster, p): assert ster == "1" assert p == "P_dispatcher"