Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use HTTPStatus constants in place of literals in tests #13298

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Use HTTPStatus constants in place of literals in tests.
  • Loading branch information
dklimpel committed Jul 15, 2022
commit 7e8ebbc5b2ab33b100a17550cc5ee20f634b4831
81 changes: 45 additions & 36 deletions tests/rest/client/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
Expand Down Expand Up @@ -134,10 +135,12 @@ def test_POST_ratelimiting_per_address(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -152,7 +155,7 @@ def test_POST_ratelimiting_per_address(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config(
{
Expand All @@ -179,10 +182,12 @@ def test_POST_ratelimiting_per_account(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -197,7 +202,7 @@ def test_POST_ratelimiting_per_account(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config(
{
Expand All @@ -224,10 +229,14 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
channel = self.make_request(b"POST", LOGIN_URL, params)

if i == 5:
self.assertEqual(channel.result["code"], b"429", channel.result)
self.assertEqual(
channel.code, HTTPStatus.TOO_MANY_REQUESTS, msg=channel.result
)
retry_after_ms = int(channel.json_body["retry_after_ms"])
else:
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(
channel.code, HTTPStatus.FORBIDDEN, msg=channel.result
)

# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
# than 1min.
Expand All @@ -242,15 +251,15 @@ def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

@override_config({"session_lifetime": "24h"})
def test_soft_logout(self) -> None:
self.register_user("kermit", "monkey")

# we shouldn't be able to make requests without an access token
channel = self.make_request(b"GET", TEST_URL)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")

# log in as normal
Expand Down Expand Up @@ -354,7 +363,7 @@ def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:

# Now try to hard logout this session
channel = self.make_request(b"POST", "/logout", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
Expand All @@ -380,7 +389,7 @@ def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(

# Now try to hard log out all of the user's sessions
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_with_overly_long_device_id_fails(self) -> None:
self.register_user("mickey", "cheese")
Expand Down Expand Up @@ -878,17 +887,17 @@ def jwt_login(self, *args: Any) -> FakeChannel:
def test_login_jwt_valid_registered(self) -> None:
self.register_user("kermit", "monkey")
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_valid_unregistered(self) -> None:
channel = self.jwt_login({"sub": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -897,7 +906,7 @@ def test_login_jwt_invalid_signature(self) -> None:

def test_login_jwt_expired(self) -> None:
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -907,7 +916,7 @@ def test_login_jwt_expired(self) -> None:
def test_login_jwt_not_before(self) -> None:
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -916,7 +925,7 @@ def test_login_jwt_not_before(self) -> None:

def test_login_no_sub(self) -> None:
channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")

Expand All @@ -925,12 +934,12 @@ def test_login_iss(self) -> None:
"""Test validating the issuer claim."""
# A valid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -939,7 +948,7 @@ def test_login_iss(self) -> None:

# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -949,20 +958,20 @@ def test_login_iss(self) -> None:
def test_login_iss_no_config(self) -> None:
"""Test providing an issuer claim without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
def test_login_aud(self) -> None:
"""Test validating the audience claim."""
# A valid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -971,7 +980,7 @@ def test_login_aud(self) -> None:

# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -981,7 +990,7 @@ def test_login_aud(self) -> None:
def test_login_aud_no_config(self) -> None:
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand All @@ -991,20 +1000,20 @@ def test_login_aud_no_config(self) -> None:
def test_login_default_sub(self) -> None:
"""Test reading user ID from the default subject claim."""
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
def test_login_custom_sub(self) -> None:
"""Test reading user ID from a custom subject claim."""
channel = self.jwt_login({"username": "frog"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@frog:test")

def test_login_no_token(self) -> None:
params = {"type": "org.matrix.login.jwt"}
channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")

Expand Down Expand Up @@ -1086,12 +1095,12 @@ def jwt_login(self, *args: Any) -> FakeChannel:

def test_login_jwt_valid(self) -> None:
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)
self.assertEqual(channel.json_body["user_id"], "@kermit:test")

def test_login_jwt_invalid_signature(self) -> None:
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
Expand Down Expand Up @@ -1152,7 +1161,7 @@ def test_login_appservice_user(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_appservice_user_bot(self) -> None:
"""Test that the appservice bot can use /login"""
Expand All @@ -1166,7 +1175,7 @@ def test_login_appservice_user_bot(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"200", channel.result)
self.assertEqual(channel.code, HTTPStatus.OK, msg=channel.result)

def test_login_appservice_wrong_user(self) -> None:
"""Test that non-as users cannot login with the as token"""
Expand All @@ -1180,7 +1189,7 @@ def test_login_appservice_wrong_user(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.service.token
)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

def test_login_appservice_wrong_as(self) -> None:
"""Test that as users cannot login with wrong as token"""
Expand All @@ -1194,7 +1203,7 @@ def test_login_appservice_wrong_as(self) -> None:
b"POST", LOGIN_URL, params, access_token=self.another_service.token
)

self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, msg=channel.result)

def test_login_appservice_no_token(self) -> None:
"""Test that users must provide a token when using the appservice
Expand All @@ -1208,7 +1217,7 @@ def test_login_appservice_no_token(self) -> None:
}
channel = self.make_request(b"POST", LOGIN_URL, params)

self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, msg=channel.result)


@skip_unless(HAS_OIDC, "requires OIDC")
Expand Down
31 changes: 24 additions & 7 deletions tests/rest/client/test_redactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import List

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -67,7 +68,11 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
)

def _redact_event(
self, access_token: str, room_id: str, event_id: str, expect_code: int = 200
self,
access_token: str,
room_id: str,
event_id: str,
expect_code: int = HTTPStatus.OK,
) -> JsonDict:
"""Helper function to send a redaction event.

Expand All @@ -76,12 +81,12 @@ def _redact_event(
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)

channel = self.make_request("POST", path, content={}, access_token=access_token)
self.assertEqual(int(channel.result["code"]), expect_code)
self.assertEqual(channel.code, expect_code)
return channel.json_body

def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.code, HTTPStatus.OK)
room_sync = channel.json_body["rooms"]["join"][room_id]
return room_sync["timeline"]["events"]

Expand Down Expand Up @@ -117,7 +122,10 @@ def test_redact_event_as_normal(self) -> None:

# as a normal, try to redact the admin's event
self._redact_event(
self.other_access_token, self.room_id, admin_msg_id, expect_code=403
self.other_access_token,
self.room_id,
admin_msg_id,
expect_code=HTTPStatus.FORBIDDEN
)

# now try to redact our own event
Expand Down Expand Up @@ -153,7 +161,10 @@ def test_redact_nonexistent_event(self) -> None:

# ... but normals cannot
self._redact_event(
self.other_access_token, self.room_id, "$zzz", expect_code=404
self.other_access_token,
self.room_id,
"$zzz",
expect_code=HTTPStatus.NOT_FOUND,
)

# when we sync, we should see only the valid redaction
Expand All @@ -178,12 +189,18 @@ def test_redact_create_event(self) -> None:

# room moderators cannot send redactions for create events
self._redact_event(
self.mod_access_token, self.room_id, create_event_id, expect_code=403
self.mod_access_token,
self.room_id,
create_event_id,
expect_code=HTTPStatus.FORBIDDEN,
)

# and nor can normals
self._redact_event(
self.other_access_token, self.room_id, create_event_id, expect_code=403
self.other_access_token,
self.room_id,
create_event_id,
expect_code=HTTPStatus.FORBIDDEN,
)

def test_redact_event_as_moderator_ratelimit(self) -> None:
Expand Down
Loading