Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type annotations for the pyramid and tornado parsers #944

Merged
merged 3 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,3 @@ disallow_untyped_defs = false

[mypy-webargs.falconparser]
disallow_untyped_defs = false

[mypy-webargs.pyramidparser]
disallow_untyped_defs = false

[mypy-webargs.tornadoparser]
disallow_untyped_defs = false
2 changes: 1 addition & 1 deletion src/webargs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
)
return func(*args, **kwargs)

wrapper.__wrapped__ = func # type: ignore
wrapper.__wrapped__ = func
_record_arg_name(wrapper, arg_name)
return wrapper

Expand Down
61 changes: 39 additions & 22 deletions src/webargs/pyramidparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def hello_world(request, args):
from __future__ import annotations

import functools
import typing
from collections.abc import Mapping

import marshmallow as ma
Expand All @@ -38,6 +39,8 @@ def hello_world(request, args):
from webargs import core
from webargs.core import json

F = typing.TypeVar("F", bound=typing.Callable)


def is_json_request(req: Request) -> bool:
return core.is_json(req.headers.get("content-type"))
Expand All @@ -57,7 +60,7 @@ class PyramidParser(core.Parser[Request]):
**core.Parser.__location_map__,
)

def _raw_load_json(self, req: Request):
def _raw_load_json(self, req: Request) -> typing.Any:
"""Return a json payload from the request for the core parser's load_json

Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -67,34 +70,40 @@ def _raw_load_json(self, req: Request):

return core.parse_json(req.body, encoding=req.charset)

def load_querystring(self, req: Request, schema):
def load_querystring(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(req.GET, schema)

def load_form(self, req: Request, schema):
def load_form(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(req.POST, schema)

def load_cookies(self, req: Request, schema):
def load_cookies(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return cookies from the request as a MultiDictProxy."""
return self._makeproxy(req.cookies, schema)

def load_headers(self, req: Request, schema):
def load_headers(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return headers from the request as a MultiDictProxy."""
return self._makeproxy(req.headers, schema)

def load_files(self, req: Request, schema):
def load_files(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return files from the request as a MultiDictProxy."""
files = ((k, v) for k, v in req.POST.items() if hasattr(v, "file"))
return self._makeproxy(MultiDict(files), schema)

def load_matchdict(self, req: Request, schema):
def load_matchdict(self, req: Request, schema: ma.Schema) -> typing.Any:
"""Return the request's ``matchdict`` as a MultiDictProxy."""
return self._makeproxy(req.matchdict, schema)

def handle_error(
self, error, req: Request, schema, *, error_status_code, error_headers
):
self,
error: ma.ValidationError,
req: Request,
schema: ma.Schema,
*,
error_status_code: int | None,
error_headers: typing.Mapping[str, str] | None,
) -> typing.NoReturn:
"""Handles errors during parsing. Aborts the current HTTP request and
responds with a 400 error.
"""
Expand All @@ -109,7 +118,13 @@ def handle_error(
response.body = body.encode("utf-8") if isinstance(body, str) else body
raise response

def _handle_invalid_json_error(self, error, req: Request, *args, **kwargs):
def _handle_invalid_json_error(
self,
error: json.JSONDecodeError | UnicodeDecodeError,
req: Request,
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.NoReturn:
messages = {"json": ["Invalid JSON body."]}
response = exception_response(
400, detail=str(messages), content_type="application/json"
Expand All @@ -120,17 +135,17 @@ def _handle_invalid_json_error(self, error, req: Request, *args, **kwargs):

def use_args(
self,
argmap,
argmap: core.ArgMap,
req: Request | None = None,
*,
location=core.Parser.DEFAULT_LOCATION,
unknown=None,
as_kwargs=False,
arg_name=None,
validate=None,
error_status_code=None,
error_headers=None,
):
location: str | None = core.Parser.DEFAULT_LOCATION,
unknown: str | None = None,
as_kwargs: bool = False,
arg_name: str | None = None,
validate: core.ValidateArg = None,
error_status_code: int | None = None,
error_headers: typing.Mapping[str, str] | None = None,
) -> typing.Callable[..., typing.Callable]:
"""Decorator that injects parsed arguments into a view callable.
Supports the *Class-based View* pattern where `request` is saved as an instance
attribute on a view class.
Expand Down Expand Up @@ -167,9 +182,11 @@ def use_args(
argmap = dict(argmap)
argmap = self.schema_class.from_dict(argmap)()

def decorator(func):
def decorator(func: F) -> F:
@functools.wraps(func)
def wrapper(obj, *args, **kwargs):
def wrapper(
obj: typing.Any, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
# The first argument is either `self` or `request`
try: # get self.request
request = req or obj.request
Expand All @@ -191,7 +208,7 @@ def wrapper(obj, *args, **kwargs):
return func(obj, *args, **kwargs)

wrapper.__wrapped__ = func
return wrapper
return wrapper # type: ignore[return-value]

return decorator

Expand Down
49 changes: 34 additions & 15 deletions src/webargs/tornadoparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ def get(self, args):
self.write(response)
"""

import json
import typing

import marshmallow as ma
import tornado.concurrent
import tornado.web
from tornado.escape import _unicode
Expand All @@ -26,13 +30,13 @@ def get(self, args):
class HTTPError(tornado.web.HTTPError):
"""`tornado.web.HTTPError` that stores validation errors."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
self.messages = kwargs.pop("messages", {})
self.headers = kwargs.pop("headers", None)
super().__init__(*args, **kwargs)


def is_json_request(req: HTTPServerRequest):
def is_json_request(req: HTTPServerRequest) -> bool:
content_type = req.headers.get("Content-Type")
return content_type is not None and core.is_json(content_type)

Expand All @@ -43,7 +47,7 @@ class WebArgsTornadoMultiDictProxy(MultiDictProxy):
requirements.
"""

def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
try:
value = self.data.get(key, core.missing)
if value is core.missing:
Expand All @@ -70,7 +74,7 @@ class WebArgsTornadoCookiesMultiDictProxy(MultiDictProxy):
Also, does not use the `_unicode` decoding step
"""

def __getitem__(self, key):
def __getitem__(self, key: str) -> typing.Any:
cookie = self.data.get(key, core.missing)
if cookie is core.missing:
return core.missing
Expand All @@ -82,7 +86,7 @@ def __getitem__(self, key):
class TornadoParser(core.Parser[HTTPServerRequest]):
"""Tornado request argument parser."""

def _raw_load_json(self, req: HTTPServerRequest):
def _raw_load_json(self, req: HTTPServerRequest) -> typing.Any:
"""Return a json payload from the request for the core parser's load_json

Checks the input mimetype and may return 'missing' if the mimetype is
Expand All @@ -97,37 +101,43 @@ def _raw_load_json(self, req: HTTPServerRequest):

return core.parse_json(req.body)

def load_querystring(self, req: HTTPServerRequest, schema):
def load_querystring(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return query params from the request as a MultiDictProxy."""
return self._makeproxy(
req.query_arguments, schema, cls=WebArgsTornadoMultiDictProxy
)

def load_form(self, req: HTTPServerRequest, schema):
def load_form(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return form values from the request as a MultiDictProxy."""
return self._makeproxy(
req.body_arguments, schema, cls=WebArgsTornadoMultiDictProxy
)

def load_headers(self, req: HTTPServerRequest, schema):
def load_headers(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return headers from the request as a MultiDictProxy."""
return self._makeproxy(req.headers, schema, cls=WebArgsTornadoMultiDictProxy)

def load_cookies(self, req: HTTPServerRequest, schema):
def load_cookies(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return cookies from the request as a MultiDictProxy."""
# use the specialized subclass specifically for handling Tornado
# cookies
return self._makeproxy(
req.cookies, schema, cls=WebArgsTornadoCookiesMultiDictProxy
)

def load_files(self, req: HTTPServerRequest, schema):
def load_files(self, req: HTTPServerRequest, schema: ma.Schema) -> typing.Any:
"""Return files from the request as a MultiDictProxy."""
return self._makeproxy(req.files, schema, cls=WebArgsTornadoMultiDictProxy)

def handle_error(
self, error, req: HTTPServerRequest, schema, *, error_status_code, error_headers
):
self,
error: ma.ValidationError,
req: HTTPServerRequest,
schema: ma.Schema,
*,
error_status_code: int | None,
error_headers: typing.Mapping[str, str] | None,
) -> typing.NoReturn:
"""Handles errors during parsing. Raises a `tornado.web.HTTPError`
with a 400 error.
"""
Expand All @@ -145,16 +155,25 @@ def handle_error(
)

def _handle_invalid_json_error(
self, error, req: HTTPServerRequest, *args, **kwargs
):
self,
error: json.JSONDecodeError | UnicodeDecodeError,
req: HTTPServerRequest,
*args: typing.Any,
**kwargs: typing.Any,
) -> typing.NoReturn:
raise HTTPError(
400,
log_message="Invalid JSON body.",
reason="Bad Request",
messages={"json": ["Invalid JSON body."]},
)

def get_request_from_view_args(self, view, args, kwargs):
def get_request_from_view_args(
self,
view: typing.Any,
args: tuple[typing.Any, ...],
kwargs: typing.Mapping[str, typing.Any],
) -> HTTPServerRequest:
return args[0].request


Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ commands = pre-commit run --all-files
# `webargs` and `marshmallow` both installed is a valuable safeguard against
# issues in which `mypy` running on every file standalone won't catch things
[testenv:mypy]
deps = mypy==1.8.0
deps = mypy==1.10.0
extras = frameworks
commands = mypy src/ {posargs}

Expand Down
Loading