Skip to content

Commit

Permalink
release 0.17.0, add derive and more utilles in dispatcher interface
Browse files Browse the repository at this point in the history
  • Loading branch information
GreyElaina committed Jun 6, 2022
1 parent d63a7ef commit c1ece2a
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 180 deletions.
22 changes: 9 additions & 13 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
{
"python.linting.enabled": false,
"python.linting.pylintEnabled": true,
"python.linting.banditEnabled": false,
"maven.view": "hierarchical",
"python.formatting.provider": "black",
"python.pythonPath": "C:\\Users\\Chenw\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\graia-broadcast-kkIP7ti5-py3.8",
"jupyter.jupyterServerType": "local",
"cSpell.words": [
"oplog",
"Unexisted",
"utilles"
]
}
"python.linting.enabled": false,
"python.linting.pylintEnabled": true,
"python.linting.banditEnabled": false,
"maven.view": "hierarchical",
"python.formatting.provider": "black",
"python.pythonPath": "C:\\Users\\Chenw\\AppData\\Local\\pypoetry\\Cache\\virtualenvs\\graia-broadcast-kkIP7ti5-py3.8",
"jupyter.jupyterServerType": "local",
"cSpell.words": ["Dispatchable", "oplog", "Unexisted", "utilles"]
}
184 changes: 60 additions & 124 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
[tool.poetry]
name = "graia-broadcast"
version = "0.16.1"
version = "0.17.0"
description = "a highly customizable, elegantly designed event system based on asyncio"
authors = ["GreyElaina <GreyElaina@outlook.com>"]
license = "MIT"
packages = [{ include = "graia", from = "src" }]

[tool.poetry.dependencies]
python = "^3.7"
typing-extensions = { version = "^3.10.0", python = "~3.7" }
typing-extensions = { version = "^3.10.0", python = "<3.9" }

[tool.poetry.dev-dependencies]
black = "^22.1.0"
pre-commit = "*"
flake8 = "^4.0.1"
isort = "^5.10.1"
yappi = "^1.3.2"
pyinstrument = "^4.0.4"
pytest = "^7.0.1"
coverage = "^6.3.2"
pytest-asyncio = "^0.18.2"
Expand Down
91 changes: 61 additions & 30 deletions src/graia/broadcast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import sys
import traceback
from contextlib import asynccontextmanager
from typing import Callable, Dict, Iterable, List, Optional, Type, Union

from graia.broadcast.builtin.derive import DeriveDispatcher
from graia.broadcast.entities.dispatcher import BaseDispatcher

from .builtin.event import ExceptionThrowed
Expand Down Expand Up @@ -64,7 +66,7 @@ def __init__(
self.listeners = []
self.event_ctx = Ctx("bcc_event_ctx")
self.decorator_interface = DecoratorInterface()
self.prelude_dispatchers = [self.decorator_interface]
self.prelude_dispatchers = [self.decorator_interface, DeriveDispatcher()]
self.finale_dispatchers = []

@self.prelude_dispatchers.append
Expand Down Expand Up @@ -132,11 +134,8 @@ async def Executor(
current_oplog = target.oplog.setdefault(..., {})
# also, Ellipsis is good.

if is_listener:
if target.namespace.disabled: # type: ignore
raise DisabledNamespace(
"caught a disabled namespace: {0}".format(target.namespace.name) # type: ignore
)
if is_listener and target.namespace.disabled: # type: ignore
raise DisabledNamespace("caught a disabled namespace: {0}".format(target.namespace.name)) # type: ignore

target_callable = target.callable if is_exectarget else target # type: ignore
parameter_compile_result = {}
Expand Down Expand Up @@ -214,6 +213,49 @@ async def Executor(
self.listeners.remove(target)
return result

@asynccontextmanager
async def param_compile(
self,
dispatchers: Optional[List[T_Dispatcher]] = None,
post_exception_event: bool = True,
print_exception: bool = True,
use_global_dispatchers: bool = True,
):
dispatchers = [
*(self.prelude_dispatchers if use_global_dispatchers else []),
*(dispatchers if dispatchers else []),
*(self.finale_dispatchers if use_global_dispatchers else []),
]

dii = DispatcherInterface(self, dispatchers)
dii_token = dii.ctx.set(dii)

try:
for dispatcher in dispatchers:
i = getattr(dispatcher, "beforeExecution", None)
if i:
await run_always_await_safely(i, dii, exception, tb) # type: ignore
yield dii
except RequirementCrashed:
traceback.print_exc()
raise
except Exception as e:
event: Optional[Dispatchable] = self.event_ctx.get()
if event is not None and event.__class__ is not ExceptionThrowed:
if print_exception:
traceback.print_exc()
if post_exception_event:
self.postEvent(ExceptionThrowed(exception=e, event=event))
raise
finally:
_, exception, tb = sys.exc_info()
for dispatcher in dispatchers:
i = getattr(dispatcher, "afterExecution", None)
if i:
await run_always_await_safely(i, dii, exception, tb) # type: ignore

dii.ctx.reset(dii_token)

def postEvent(self, event: Dispatchable, upper_event: Optional[Dispatchable] = None):
return self.loop.create_task(
self.layered_scheduler(
Expand Down Expand Up @@ -250,29 +292,22 @@ def createNamespace(self, name, *, priority: int = 0, hide: bool = False, disabl
return self.namespaces[-1]

def removeNamespace(self, name):
if self.containNamespace(name):
for index, i in enumerate(self.namespaces):
if i.name == name:
self.namespaces.pop(index)
return
else:
if not self.containNamespace(name):
raise UnexistedNamespace(name)
for index, i in enumerate(self.namespaces):
if i.name == name:
self.namespaces.pop(index)
return

def containNamespace(self, name):
for i in self.namespaces:
if i.name == name:
return True
return False
return any(i.name == name for i in self.namespaces)

def getNamespace(self, name) -> "Namespace":
if self.containNamespace(name):
for i in self.namespaces:
if i.name == name:
return i
else:
raise UnexistedNamespace(name)
else:
raise UnexistedNamespace(name)
raise UnexistedNamespace(name)

def hideNamespace(self, name):
ns = self.getNamespace(name)
Expand All @@ -291,10 +326,7 @@ def enableNamespace(self, name):
ns.disabled = False

def containListener(self, target):
for i in self.listeners:
if i.callable == target:
return True
return False
return any(i.callable == target for i in self.listeners)

def getListener(self, target):
for i in self.listeners:
Expand All @@ -320,8 +352,8 @@ def receiver(
priority = int(priority)

def receiver_wrapper(callable_target):
may_listener = self.getListener(callable_target)
if not may_listener:
listener = self.getListener(callable_target)
if not listener:
self.listeners.append(
Listener(
callable=callable_target,
Expand All @@ -332,11 +364,10 @@ def receiver_wrapper(callable_target):
decorators=decorators or [],
)
)
elif event in listener.listening_events:
raise RegisteredEventListener(event.__name__, "has been registered!") # type: ignore
else:
if event not in may_listener.listening_events:
may_listener.listening_events.append(event) # type: ignore
else:
raise RegisteredEventListener(event.__name__, "has been registered!") # type: ignore
listener.listening_events.append(event) # type: ignore
return callable_target

return receiver_wrapper
31 changes: 31 additions & 0 deletions src/graia/broadcast/builtin/derive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

from typing import Protocol, TypeVar

from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.interfaces.dispatcher import DispatcherInterface

try:
from typing import get_args
except ImportError:
from typing_extensions import get_args


T = TypeVar("T")


class Derive(Protocol[T]):
async def __call__(self, value: T, dispatcher_interface: DispatcherInterface) -> T:
...


class DeriveDispatcher(BaseDispatcher):
async def catch(self, interface: DispatcherInterface):
if not interface.is_annotated:
return
args = get_args(interface.annotation)
origin_arg, meta = args[0], args[1:]
result = await interface.lookup_param(interface.name, origin_arg, interface.default)
for i in meta:
result = await i(result, interface)
return result
63 changes: 54 additions & 9 deletions src/graia/broadcast/interfaces/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,29 @@
Set,
Tuple,
TypeVar,
Union,
)

from ..entities.dispatcher import BaseDispatcher
from ..entities.event import Dispatchable
from ..entities.signatures import Force
from ..exceptions import RequirementCrashed
from ..exceptions import ExecutionStop, RequirementCrashed
from ..typing import T_Dispatcher
from ..utilles import Ctx, NestableIterable

try:
from typing import get_args, get_origin
except ImportError:
from typing_extensions import get_args, get_origin

try:
from typing import Annotated

from typing_extensions import Annotated as AlterAnnotated
except ImportError:
from typing_extensions import Annotated

AlterAnnotated = None

if TYPE_CHECKING:
from .. import Broadcast

Expand Down Expand Up @@ -71,10 +85,42 @@ def default(self) -> Any:
def event(self) -> T_Event:
return self.broadcast.event_ctx.get() # type: ignore

@property
def is_optional(self) -> bool:
anno = self.annotation
return get_origin(anno) is Union and None in get_args(anno)

@property
def is_annotated(self) -> bool:
return get_origin(self.annotation) in {Annotated, AlterAnnotated}

@property
def annotated_origin(self) -> Any:
if not self.is_annotated:
raise TypeError("required a annotated annotation")
return get_args(self.annotation)[0]

@property
def annotated_metadata(self) -> tuple:
if not self.is_annotated:
raise TypeError("required a annotated annotation")
return get_args(self.annotation)[1:]

def inject_execution_raw(self, *dispatchers: T_Dispatcher):
for dispatcher in dispatchers:
self.dispatchers.insert(0, dispatcher)

def crash(self):
raise RequirementCrashed(
"the dispatching requirement crashed: ",
self.name,
self.annotation,
self.default,
)

def stop(self):
raise ExecutionStop

async def lookup_param(
self,
name: str,
Expand Down Expand Up @@ -109,13 +155,12 @@ async def lookup_param(
return result.target

return result
else:
raise RequirementCrashed(
"the dispatching requirement crashed: ",
self.name,
self.annotation,
self.default,
)
raise RequirementCrashed(
"the dispatching requirement crashed: ",
self.name,
self.annotation,
self.default,
)
finally:
self.parameter_contexts.pop()

Expand Down
35 changes: 35 additions & 0 deletions src/test_derive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import asyncio
from typing import Annotated

from graia.broadcast import Broadcast, Dispatchable
from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.interfaces.dispatcher import DispatcherInterface


class ExampleEvent(Dispatchable):
class Dispatcher(BaseDispatcher):
@staticmethod
async def catch(interface: "DispatcherInterface"):
if interface.annotation is str:
return "ok, i'm."


loop = asyncio.get_event_loop()
broadcast = Broadcast(loop=loop)


async def test_derive_1(v: str, dii: DispatcherInterface):
print("in derive 1", v)
return v[1:]


@broadcast.receiver("ExampleEvent") # or just receiver(ExampleEvent)
async def event_listener(maybe_you_are_str: Annotated[str, test_derive_1, test_derive_1]):
print(maybe_you_are_str) # <<< ok, i'm


async def main():
await broadcast.postEvent(ExampleEvent()) # sync call is allowed.


loop.run_until_complete(main())

0 comments on commit c1ece2a

Please sign in to comment.