Skip to content

Commit

Permalink
Add typing for the ASGI messages
Browse files Browse the repository at this point in the history
This makes use of the TypedDict (added to Python in 3.8 and available
in the typing_extensions package) to specify the types of the key
values in the ASGI messages. This then ensures that the messages are
correctly constructued and used in the code.

Note the type ignores are due to this issue
python/mypy#8533.
  • Loading branch information
pgjones committed Dec 24, 2020
1 parent 5f93a2c commit 39094e7
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 48 deletions.
23 changes: 17 additions & 6 deletions src/hypercorn/asyncio/context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import asyncio
from typing import Any, Awaitable, Callable, Type, Union
from typing import Any, Awaitable, Callable, Optional, Type, Union

from .task_group import TaskGroup
from ..config import Config
from ..typing import ASGIFramework, Event, Scope
from ..typing import (
ASGIFramework,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendEvent,
Event,
Scope,
)
from ..utils import invoke_asgi


Expand All @@ -22,7 +29,11 @@ async def set(self) -> None:


async def _handle(
app: ASGIFramework, config: Config, scope: Scope, receive: Callable, send: Callable
app: ASGIFramework,
config: Config,
scope: Scope,
receive: ASGIReceiveCallable,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
try:
await invoke_asgi(app, scope, receive, send)
Expand All @@ -45,9 +56,9 @@ async def spawn_app(
app: ASGIFramework,
config: Config,
scope: Scope,
send: Callable[[dict], Awaitable[None]],
) -> Callable[[dict], Awaitable[None]]:
app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size)
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue(config.max_app_queue_size)
self.task_group.spawn(_handle(app, config, scope, app_queue.get, send))
return app_queue.put

Expand Down
6 changes: 3 additions & 3 deletions src/hypercorn/asyncio/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio

from ..config import Config
from ..typing import ASGIFramework, LifespanScope
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout


Expand Down Expand Up @@ -67,10 +67,10 @@ async def wait_for_shutdown(self) -> None:
except asyncio.TimeoutError as error:
raise LifespanTimeout("shutdown") from error

async def asgi_receive(self) -> dict:
async def asgi_receive(self) -> ASGIReceiveEvent:
return await self.app_queue.get()

async def asgi_send(self, message: dict) -> None:
async def asgi_send(self, message: ASGISendEvent) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup.set()
elif message["type"] == "lifespan.shutdown.complete":
Expand Down
14 changes: 10 additions & 4 deletions src/hypercorn/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

if TYPE_CHECKING:
from .config import Config
from .typing import WWWScope
from .typing import ResponseSummary, WWWScope


def _create_logger(
Expand Down Expand Up @@ -65,7 +65,9 @@ def __init__(self, config: "Config") -> None:
if config.logconfig_dict is not None:
dictConfig(config.logconfig_dict)

async def access(self, request: "WWWScope", response: dict, request_time: float) -> None:
async def access(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
if self.access_logger is not None:
self.access_logger.info(
self.access_log_format, self.atoms(request, response, request_time)
Expand Down Expand Up @@ -99,7 +101,9 @@ async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None
if self.error_logger is not None:
self.error_logger.log(level, message, *args, **kwargs)

def atoms(self, request: "WWWScope", response: dict, request_time: float) -> Mapping[str, str]:
def atoms(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> Mapping[str, str]:
"""Create and return an access log atoms dictionary.
This can be overidden and customised if desired. It should
Expand All @@ -112,7 +116,9 @@ def __getattr__(self, name: str) -> Any:


class AccessLogAtoms(dict):
def __init__(self, request: "WWWScope", response: dict, request_time: float) -> None:
def __init__(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
for name, value in request["headers"]:
self[f"{{{name.decode('latin1').lower()}}}i"] = value.decode("latin1")
for name, value in response.get("headers", []):
Expand Down
8 changes: 4 additions & 4 deletions src/hypercorn/protocol/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from .events import Body, EndBody, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import ASGIFramework, Context, HTTPScope
from ..typing import ASGIFramework, ASGISendEvent, Context, HTTPResponseStartEvent, HTTPScope
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name

PUSH_VERSIONS = {"2", "3"}
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
self.closed = False
self.config = config
self.context = context
self.response: dict
self.response: HTTPResponseStartEvent
self.scope: HTTPScope
self.send = send
self.scheme = "https" if ssl else "http"
Expand Down Expand Up @@ -91,9 +91,9 @@ async def handle(self, event: Event) -> None:
elif isinstance(event, StreamClosed):
self.closed = True
if self.app_put is not None:
await self.app_put({"type": "http.disconnect"})
await self.app_put({"type": "http.disconnect"}) # type: ignore

async def app_send(self, message: Optional[dict]) -> None:
async def app_send(self, message: Optional[ASGISendEvent]) -> None:
if self.closed:
# Allow app to finish after close
return
Expand Down
20 changes: 14 additions & 6 deletions src/hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@

from .events import Body, Data, EndBody, EndData, Event, Request, Response, StreamClosed
from ..config import Config
from ..typing import ASGIFramework, Context, WebsocketScope
from ..typing import (
ASGIFramework,
ASGISendEvent,
Context,
WebsocketAcceptEvent,
WebsocketResponseBodyEvent,
WebsocketResponseStartEvent,
WebsocketScope,
)
from ..utils import build_and_validate_headers, suppress_body, UnexpectedMessage, valid_server_name


Expand Down Expand Up @@ -152,7 +160,7 @@ def __init__(
self.closed = False
self.config = config
self.context = context
self.response: dict
self.response: WebsocketResponseStartEvent
self.scope: WebsocketScope
self.send = send
# RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss
Expand Down Expand Up @@ -202,7 +210,7 @@ async def handle(self, event: Event) -> None:
self.app_put = await self.context.spawn_app(
self.app, self.config, self.scope, self.app_send
)
await self.app_put({"type": "websocket.connect"})
await self.app_put({"type": "websocket.connect"}) # type: ignore
elif isinstance(event, (Body, Data)):
self.connection.receive_data(event.data)
await self._handle_events()
Expand All @@ -215,7 +223,7 @@ async def handle(self, event: Event) -> None:
code = CloseReason.ABNORMAL_CLOSURE.value
await self.app_put({"type": "websocket.disconnect", "code": code})

async def app_send(self, message: Optional[dict]) -> None:
async def app_send(self, message: Optional[ASGISendEvent]) -> None:
if self.closed:
# Allow app to finish after close
return
Expand Down Expand Up @@ -304,7 +312,7 @@ async def _send_wsproto_event(self, event: WSProtoEvent) -> None:
data = self.connection.send(event)
await self.send(Data(stream_id=self.stream_id, data=data))

async def _accept(self, message: dict) -> None:
async def _accept(self, message: WebsocketAcceptEvent) -> None:
self.state = ASGIWebsocketState.CONNECTED
status_code, headers, self.connection = self.handshake.accept(message.get("subprotocol"))
await self.send(
Expand All @@ -316,7 +324,7 @@ async def _accept(self, message: dict) -> None:
if self.config.websocket_ping_interval is not None:
self.context.spawn(self._send_pings)

async def _send_rejection(self, message: dict) -> None:
async def _send_rejection(self, message: WebsocketResponseBodyEvent) -> None:
body_suppressed = suppress_body("GET", self.response["status"])
if self.state == ASGIWebsocketState.HANDSHAKE:
headers = build_and_validate_headers(self.response["headers"])
Expand Down
6 changes: 4 additions & 2 deletions src/hypercorn/statsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

if TYPE_CHECKING:
from .config import Config
from .typing import WWWScope
from .typing import ResponseSummary, WWWScope

METRIC_VAR = "metric"
VALUE_VAR = "value"
Expand Down Expand Up @@ -64,7 +64,9 @@ async def log(self, level: int, message: str, *args: Any, **kwargs: Any) -> None
except Exception:
await super().warning("Failed to log to statsd", exc_info=True)

async def access(self, request: "WWWScope", response: dict, request_time: float) -> None:
async def access(
self, request: "WWWScope", response: "ResponseSummary", request_time: float
) -> None:
await super().access(request, response, request_time)
await self.histogram("hypercorn.request.duration", request_time * 1_000)
await self.increment("hypercorn.requests", 1)
Expand Down
21 changes: 16 additions & 5 deletions src/hypercorn/trio/context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import Any, Awaitable, Callable, Type, Union
from typing import Any, Awaitable, Callable, Optional, Type, Union

import trio

from ..config import Config
from ..typing import ASGIFramework, Event, Scope
from ..typing import (
ASGIFramework,
ASGIReceiveCallable,
ASGIReceiveEvent,
ASGISendEvent,
Event,
Scope,
)
from ..utils import invoke_asgi


Expand All @@ -22,7 +29,11 @@ async def set(self) -> None:


async def _handle(
app: ASGIFramework, config: Config, scope: Scope, receive: Callable, send: Callable
app: ASGIFramework,
config: Config,
scope: Scope,
receive: ASGIReceiveCallable,
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> None:
try:
await invoke_asgi(app, scope, receive, send)
Expand Down Expand Up @@ -54,8 +65,8 @@ async def spawn_app(
app: ASGIFramework,
config: Config,
scope: Scope,
send: Callable[[dict], Awaitable[None]],
) -> Callable[[dict], Awaitable[None]]:
send: Callable[[Optional[ASGISendEvent]], Awaitable[None]],
) -> Callable[[ASGIReceiveEvent], Awaitable[None]]:
app_send_channel, app_receive_channel = trio.open_memory_channel(config.max_app_queue_size)
self.nursery.start_soon(_handle, app, config, scope, app_receive_channel.receive, send)
return app_send_channel.send
Expand Down
6 changes: 3 additions & 3 deletions src/hypercorn/trio/lifespan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import trio

from ..config import Config
from ..typing import ASGIFramework, LifespanScope
from ..typing import ASGIFramework, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..utils import invoke_asgi, LifespanFailure, LifespanTimeout


Expand Down Expand Up @@ -63,10 +63,10 @@ async def wait_for_shutdown(self) -> None:
except trio.TooSlowError as error:
raise LifespanTimeout("startup") from error

async def asgi_receive(self) -> dict:
async def asgi_receive(self) -> ASGIReceiveEvent:
return await self.app_receive_channel.receive()

async def asgi_send(self, message: dict) -> None:
async def asgi_send(self, message: ASGISendEvent) -> None:
if message["type"] == "lifespan.startup.complete":
self.startup.set()
elif message["type"] == "lifespan.shutdown.complete":
Expand Down
Loading

0 comments on commit 39094e7

Please sign in to comment.