Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
A second batch of Pydantic models for rest/client/account.py (#13687)
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson authored Sep 7, 2022
1 parent d3d9ca1 commit b58386e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 34 deletions.
1 change: 1 addition & 0 deletions changelog.d/13687.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/msisdn/requestToken`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidmsisdnrequesttoken) and [`/org.matrix.msc3720/account_status`](https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/user_status/proposals/3720-account-status.md#post-_matrixclientv1account_status).
19 changes: 17 additions & 2 deletions synapse/http/servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
overload,
)

from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
from pydantic.error_wrappers import ErrorWrapper
from typing_extensions import Literal

from twisted.web.server import Request
Expand Down Expand Up @@ -714,7 +715,21 @@ def parse_and_validate_json_object_from_request(
try:
instance = model_type.parse_obj(content)
except ValidationError as e:
raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
# Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
# more specific error if possible (which occasionally helps us to be spec-
# compliant) This is a bit awkward because the spec's error codes aren't very
# clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
errcode = Codes.BAD_JSON

raw_errors = e.raw_errors
if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
raw_error = raw_errors[0].exc
if isinstance(raw_error, MissingError):
errcode = Codes.MISSING_PARAM
elif isinstance(raw_error, PydanticValueError):
errcode = Codes.INVALID_PARAM

raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)

return instance

Expand Down
54 changes: 26 additions & 28 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse

from pydantic import StrictBool, StrictStr, constr
Expand All @@ -41,7 +41,11 @@
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
from synapse.rest.client.models import (
AuthenticationData,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
Expand Down Expand Up @@ -400,23 +404,16 @@ def __init__(self, hs: "HomeServer"):
self.identity_handler = hs.get_identity_handler()

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
body = parse_and_validate_json_object_from_request(
request, MsisdnRequestTokenBody
)
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)

country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param

msisdn = phone_number_to_msisdn(country, phone_number)
msisdn = phone_number_to_msisdn(body.country, body.phone_number)

if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
# TODO: is this error message accurate? Looks like we've only rejected
# this phone number, not necessarily all phone numbers
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
Expand All @@ -425,9 +422,9 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
request, "msisdn", msisdn
)

if next_link:
if body.next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
assert_valid_next_link(self.hs, body.next_link)

existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)

Expand All @@ -454,15 +451,15 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.registration.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
send_attempt,
next_link,
body.country,
body.phone_number,
body.client_secret,
body.send_attempt,
body.next_link,
)

threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
send_attempt
body.send_attempt
)

return 200, ret
Expand Down Expand Up @@ -845,17 +842,18 @@ def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._account_handler = hs.get_account_handler()

class PostBody(RequestBodyModel):
# TODO: we could validate that each user id is an mxid here, and/or parse it
# as a UserID
user_ids: List[StrictStr]

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self._auth.get_user_by_req(request)

body = parse_json_object_from_request(request)
if "user_ids" not in body:
raise SynapseError(
400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
)
body = parse_and_validate_json_object_from_request(request, self.PostBody)

statuses, failures = await self._account_handler.get_account_statuses(
body["user_ids"],
body.user_ids,
allow_remote=True,
)

Expand Down
24 changes: 20 additions & 4 deletions synapse/rest/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class AuthenticationData(RequestBodyModel):
(The name "Authentication Data" is taken directly from the spec.)
Additional keys will be present, depending on the `type` field. Use `.dict()` to
access them.
Additional keys will be present, depending on the `type` field. Use
`.dict(exclude_unset=True)` to access them.
"""

class Config:
Expand All @@ -36,7 +36,7 @@ class Config:
type: Optional[StrictStr] = None


class EmailRequestTokenBody(RequestBodyModel):
class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
Expand All @@ -47,7 +47,7 @@ class EmailRequestTokenBody(RequestBodyModel):
max_length=255,
strict=True,
)
email: StrictStr

id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
Expand All @@ -61,9 +61,25 @@ def token_required_for_identity_server(
raise ValueError("id_access_token is required if an id_server is supplied.")
return token


class EmailRequestTokenBody(ThreePidRequestTokenBody):
email: StrictStr

# Canonicalise the email address. The addresses are all stored canonicalised
# in the database. This allows the user to reset his password without having to
# know the exact spelling (eg. upper and lower case) of address in the database.
# Without this, an email stored in the database as "foo@bar.com" would cause
# user requests for "FOO@bar.com" to raise a Not Found error.
_email_validator = validator("email", allow_reuse=True)(validate_email)


if TYPE_CHECKING:
ISO3116_1_Alpha_2 = StrictStr
else:
# Per spec: two-letter uppercase ISO-3166-1-alpha-2
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)


class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr

0 comments on commit b58386e

Please sign in to comment.