Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use HTTPStatus constants in place of literals in tests #13297

Merged
merged 3 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13297.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use `HTTPStatus` constants in place of literals in tests.
5 changes: 3 additions & 2 deletions tests/federation/test_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from http import HTTPStatus
from unittest.mock import Mock

from synapse.api.errors import Codes, SynapseError
Expand Down Expand Up @@ -50,7 +51,7 @@ def test_complexity_simple(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)

Expand All @@ -62,7 +63,7 @@ def test_complexity_simple(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(200, channel.code)
self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)

Expand Down
11 changes: 6 additions & 5 deletions tests/federation/test_federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus

from parameterized import parameterized

Expand Down Expand Up @@ -58,7 +59,7 @@ def test_bad_request(self, query_content):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")


Expand Down Expand Up @@ -119,7 +120,7 @@ def test_needs_to_be_in_room(self):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")


Expand Down Expand Up @@ -153,7 +154,7 @@ def _make_join(self, user_id) -> JsonDict:
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body

def test_send_join(self):
Expand All @@ -171,7 +172,7 @@ def test_send_join(self):
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# we should get complete room state back
returned_state = [
Expand Down Expand Up @@ -226,7 +227,7 @@ def test_send_join_partial_state(self):
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)

# expect a reduced room state
returned_state = [
Expand Down
5 changes: 3 additions & 2 deletions tests/federation/transport/test_knocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List

from synapse.api.constants import EventTypes, JoinRules, Membership
Expand Down Expand Up @@ -255,7 +256,7 @@ def test_room_state_returned_when_knocking(self):
RoomVersions.V7.identifier,
),
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
Expand Down Expand Up @@ -293,7 +294,7 @@ def test_room_state_returned_when_knocking(self):
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEqual(200, channel.code, channel.result)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
Expand Down
41 changes: 21 additions & 20 deletions tests/handlers/test_password_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for the password_auth_provider interface"""

from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock

Expand Down Expand Up @@ -188,14 +189,14 @@ def password_only_auth_provider_login_test_body(self):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()

# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
Expand All @@ -204,7 +205,7 @@ def password_only_auth_provider_login_test_body(self):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
Expand Down Expand Up @@ -258,10 +259,10 @@ def local_user_fallback_login_test_body(self):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)

channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])

@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
Expand Down Expand Up @@ -382,7 +383,7 @@ def password_auth_disabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called()

@override_config(legacy_providers_config(LegacyCustomAuthProvider))
Expand All @@ -406,14 +407,14 @@ def custom_auth_provider_login_test_body(self):

# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
Expand All @@ -427,7 +428,7 @@ def custom_auth_provider_login_test_body(self):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
Expand Down Expand Up @@ -510,7 +511,7 @@ def custom_auth_provider_callback_test_body(self):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
Expand Down Expand Up @@ -549,7 +550,7 @@ def custom_auth_password_disabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

@override_config(
Expand Down Expand Up @@ -584,7 +585,7 @@ def custom_auth_password_disabled_localdb_enabled_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()

@override_config(
Expand Down Expand Up @@ -615,7 +616,7 @@ def password_custom_auth_password_disabled_login_test_body(self):

# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()

Expand Down Expand Up @@ -646,13 +647,13 @@ def password_custom_auth_password_disabled_ui_auth_test_body(self):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"]

channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)

# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
Expand Down Expand Up @@ -721,7 +722,7 @@ def custom_auth_no_local_user_fallback_test_body(self):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)

def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
Expand Down Expand Up @@ -884,7 +885,7 @@ def _test_3pid_allowed(self, username: str, registration: bool):
},
access_token=tok,
)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
Expand All @@ -906,7 +907,7 @@ def _test_3pid_allowed(self, username: str, registration: bool):
},
access_token=tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body)

m.assert_called_once_with("email", "bar@test.com", registration)
Expand Down Expand Up @@ -949,12 +950,12 @@ def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, 200, channel.json_body)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body

def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"]

def _send_password_login(self, user: str, password: str) -> FakeChannel:
Expand Down
16 changes: 8 additions & 8 deletions tests/rest/admin/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ def test_create_server_admin(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def test_create_user(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
Expand Down Expand Up @@ -1512,7 +1512,7 @@ def test_create_user_mau_limit_reached_active_admin(self) -> None:
content={"password": "abc123", "admin": False},
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])

Expand Down Expand Up @@ -1550,7 +1550,7 @@ def test_create_user_mau_limit_reached_passive_admin(self) -> None:
)

# Admin user is not blocked by mau anymore
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])

Expand Down Expand Up @@ -1585,7 +1585,7 @@ def test_create_user_email_notif_for_new_users(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -1626,7 +1626,7 @@ def test_create_user_email_no_notif_for_new_users(self) -> None:
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -1666,7 +1666,7 @@ def test_create_user_email_notif_for_new_users_with_msisdn_threepid(self) -> Non
content=body,
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
Expand Down Expand Up @@ -2407,7 +2407,7 @@ def test_accidental_deactivation_prevention(self) -> None:
content={"password": "abc123"},
)

self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])

Expand Down
Loading