Skip to content

Commit

Permalink
fix is_optional
Browse files Browse the repository at this point in the history
add tests for derive
  • Loading branch information
BlueGlassBlock committed Jun 24, 2022
1 parent 631c017 commit a6e4786
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,5 @@ exclude_lines = [
"def __repr__",
"def __str__",
"def __repr_args__",
"except ImportError:", # Don't complain about import fallback
]
2 changes: 1 addition & 1 deletion src/graia/broadcast/interfaces/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def event(self) -> T_Event:
@property
def is_optional(self) -> bool:
anno = self.annotation
return get_origin(anno) is Union and None in get_args(anno)
return get_origin(anno) is Union and type(None) in get_args(anno)

@property
def is_annotated(self) -> bool:
Expand Down
46 changes: 36 additions & 10 deletions src/test/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,30 @@
from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.entities.event import Dispatchable
from graia.broadcast.entities.signatures import Force
from graia.broadcast.exceptions import ExecutionStop, RequirementCrashed
from graia.broadcast.interfaces.dispatcher import DispatcherInterface


class TestDispatcher(BaseDispatcher):
@staticmethod
async def catch(interface: DispatcherInterface):
class RandomDispatcher(BaseDispatcher):
def __init__(self, t: bool = False) -> None:
self.second_exec = t

async def catch(self, interface: DispatcherInterface):
if interface.name == "p":
return "P_dispatcher"
elif interface.name == "f":
if self.second_exec:
return
self.second_exec = True
return Force(2)


class CrashDispatcher(BaseDispatcher):
@staticmethod
async def catch(i: DispatcherInterface):
i.crash()


class TestEvent(Dispatchable):
class Dispatcher(BaseDispatcher):
@staticmethod
Expand All @@ -36,8 +48,21 @@ async def test_lookup_directly():

dii = DispatcherInterface(bcc, [])

assert await dii.lookup_by_directly(TestDispatcher, "p", None, None) == "P_dispatcher"
assert await dii.lookup_by_directly(TestDispatcher, "f", None, None) == 2
assert await dii.lookup_by_directly(RandomDispatcher(), "p", None, None) == "P_dispatcher"
assert await dii.lookup_by_directly(RandomDispatcher(), "f", None, None) == 2


@pytest.mark.asyncio
async def test_crash():
bcc = Broadcast(
loop=asyncio.get_running_loop(),
)

dii = DispatcherInterface(bcc, [])
with pytest.raises(RequirementCrashed):
await dii.lookup_by_directly(CrashDispatcher, "u", None, None)
with pytest.raises(ExecutionStop):
dii.stop()


@pytest.mark.asyncio
Expand All @@ -48,8 +73,8 @@ async def test_insert():

dii = DispatcherInterface(bcc, [])

t_a = TestDispatcher()
t_b = TestDispatcher()
t_a = RandomDispatcher()
t_b = RandomDispatcher()
dii.inject_execution_raw(t_a, t_b)
assert dii.dispatchers == [t_b, t_a]

Expand Down Expand Up @@ -83,16 +108,17 @@ async def test_dispatcher_catch():

executed = []

@bcc.receiver(TestEvent, dispatchers=[TestDispatcher])
@bcc.receiver(TestEvent, dispatchers=[RandomDispatcher(), RandomDispatcher()])
async def _1(f, b: Broadcast, i: DispatcherInterface):
assert f == 2
assert b is bcc
assert i.__class__ == DispatcherInterface
executed.append(1)

await bcc.postEvent(TestEvent())
await bcc.postEvent(TestEvent())

assert len(executed) == 1
assert len(executed) == 2


@pytest.mark.asyncio
Expand All @@ -104,7 +130,7 @@ async def test_dispatcher_priority():

executed = []

@bcc.receiver(TestEvent, dispatchers=[TestDispatcher])
@bcc.receiver(TestEvent, dispatchers=[RandomDispatcher])
async def _2(ster, p):
assert ster == "1"
assert p == "P_dispatcher"
Expand Down

0 comments on commit a6e4786

Please sign in to comment.