Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Opt for simpler Literal[...] type
Browse files Browse the repository at this point in the history
I would like to see a StrEnum but I am somewhat anxious of making that
change right now. I think we should get validation in place first and
then look at propagating better types into the application.
  • Loading branch information
David Robertson committed Sep 15, 2022
1 parent 8143310 commit 38ae8c7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 23 deletions.
6 changes: 3 additions & 3 deletions synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from urllib.parse import urlparse

from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal

from twisted.web.server import Request

Expand Down Expand Up @@ -46,7 +47,6 @@
ClientSecretType,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
ThreepidMedium,
)
from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
Expand Down Expand Up @@ -710,7 +710,7 @@ def __init__(self, hs: "HomeServer"):
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: ThreepidMedium
medium: Literal["email", "msisdn"]

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
Expand Down Expand Up @@ -744,7 +744,7 @@ def __init__(self, hs: "HomeServer"):
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: ThreepidMedium
medium: Literal["email", "msisdn"]

async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
Expand Down
7 changes: 0 additions & 7 deletions synapse/rest/client/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import TYPE_CHECKING, Dict, Optional

from pydantic import Extra, StrictInt, StrictStr, constr, validator
Expand All @@ -20,12 +19,6 @@
from synapse.util.threepids import validate_email


class ThreepidMedium(str, Enum):
# Per advice at https://pydantic-docs.helpmanual.io/usage/types/#enums-and-choices
email = "email"
msisdn = "msisdn"


class AuthenticationData(RequestBodyModel):
"""
Data used during user-interactive authentication.
Expand Down
16 changes: 3 additions & 13 deletions tests/rest/client/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
import unittest as stdlib_unittest

from pydantic import BaseModel, ValidationError
from typing_extensions import Literal

from synapse.rest.client.models import EmailRequestTokenBody, ThreepidMedium
from synapse.rest.client.models import EmailRequestTokenBody


class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: ThreepidMedium
medium: Literal["email", "msisdn"]

def test_accepts_valid_medium_string(self) -> None:
"""Sanity check that Pydantic behaves sensibly with an enum-of-str
Expand All @@ -29,18 +30,7 @@ def test_accepts_valid_medium_string(self) -> None:
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertIsInstance(model.medium, str)
self.assertEqual(model.medium, "email")
self.assertEqual(model.medium, ThreepidMedium.email)
self.assertIs(model.medium, ThreepidMedium.email)

self.assertNotEqual(model.medium, "msisdn")
self.assertNotEqual(model.medium, ThreepidMedium.msisdn)
self.assertIsNot(model.medium, ThreepidMedium.msisdn)

def test_accepts_valid_medium_enum(self) -> None:
model = self.Model.parse_obj({"medium": ThreepidMedium.email})
self.assertIs(model.medium, ThreepidMedium.email)

def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
Expand Down

0 comments on commit 38ae8c7

Please sign in to comment.