From 963f4309fe29206f3ba92b493e922280feea30ed Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 30 Mar 2021 12:06:09 +0100 Subject: [PATCH] Make RateLimiter class check for ratelimit overrides (#9711) This should fix a class of bug where we forget to check if e.g. the appservice shouldn't be ratelimited. We also check the `ratelimit_override` table to check if the user has ratelimiting disabled. That table is really only meant to override the event sender ratelimiting, so we don't use any values from it (as they might not make sense for different rate limits), but we do infer that if ratelimiting is disabled for the user we should disabled all ratelimits. Fixes #9663 --- changelog.d/9711.bugfix | 1 + synapse/api/ratelimiting.py | 100 ++++++++------ synapse/federation/federation_server.py | 5 +- synapse/handlers/_base.py | 14 +- synapse/handlers/auth.py | 24 ++-- synapse/handlers/devicemessage.py | 5 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 12 +- synapse/handlers/register.py | 6 +- synapse/handlers/room_member.py | 23 ++-- synapse/replication/http/register.py | 2 +- synapse/rest/client/v1/login.py | 14 +- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/rest/client/v2_alpha/register.py | 8 +- synapse/server.py | 1 + tests/api/test_ratelimiting.py | 168 +++++++++++++++-------- 16 files changed, 241 insertions(+), 154 deletions(-) create mode 100644 changelog.d/9711.bugfix diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix new file mode 100644 index 000000000000..4ca3438d46b4 --- /dev/null +++ b/changelog.d/9711.bugfix @@ -0,0 +1 @@ +Fix recently added ratelimits to correctly honour the application service `rate_limited` flag. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index c3f07bc1a3e8..2244b8a34062 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -31,10 +32,13 @@ class Ratelimiter: burst_count: How many actions that can be performed before being limited. """ - def __init__(self, clock: Clock, rate_hz: float, burst_count: int): + def __init__( + self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int + ): self.clock = clock self.rate_hz = rate_hz self.burst_count = burst_count + self.store = store # A ordered dictionary keeping track of actions, when they were last # performed and how often. Each entry is a mapping from a key of arbitrary type @@ -46,45 +50,10 @@ def __init__(self, clock: Clock, rate_hz: float, burst_count: int): OrderedDict() ) # type: OrderedDict[Hashable, Tuple[float, int, float]] - def can_requester_do_action( - self, - requester: Requester, - rate_hz: Optional[float] = None, - burst_count: Optional[int] = None, - update: bool = True, - _time_now_s: Optional[int] = None, - ) -> Tuple[bool, float]: - """Can the requester perform the action? - - Args: - requester: The requester to key off when rate limiting. The user property - will be used. - rate_hz: The long term number of actions that can be performed in a second. - Overrides the value set during instantiation if set. - burst_count: How many actions that can be performed before being limited. - Overrides the value set during instantiation if set. - update: Whether to count this check as performing the action - _time_now_s: The current time. Optional, defaults to the current time according - to self.clock. Only used by tests. - - Returns: - A tuple containing: - * A bool indicating if they can perform the action now - * The reactor timestamp for when the action can be performed next. - -1 if rate_hz is less than or equal to zero - """ - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return True, -1.0 - - return self.can_do_action( - requester.user.to_string(), rate_hz, burst_count, update, _time_now_s - ) - - def can_do_action( + async def can_do_action( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -92,9 +61,16 @@ def can_do_action( ) -> Tuple[bool, float]: """Can the entity (e.g. user or IP address) perform the action? + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: The key we should use when rate limiting. Can be a user ID - (when sending events), an IP address, etc. + requester: The requester that is doing the action, if any. Used to check + if the user has ratelimits disabled in the database. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -109,6 +85,30 @@ def can_do_action( * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + + if requester: + # Disable rate limiting of users belonging to any AS that is configured + # not to be rate limited in its registration file (rate_limited: true|false). + if requester.app_service and not requester.app_service.is_rate_limited(): + return True, -1.0 + + # Check if ratelimiting has been disabled for the user. + # + # Note that we don't use the returned rate/burst count, as the table + # is specifically for the event sending ratelimiter. Instead, we + # only use it to (somewhat cheekily) infer whether the user should + # be subject to any rate limiting or not. + override = await self.store.get_ratelimit_for_user( + requester.authenticated_entity + ) + if override and not override.messages_per_second: + return True, -1.0 + # Override default values if set time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() rate_hz = rate_hz if rate_hz is not None else self.rate_hz @@ -175,9 +175,10 @@ def _prune_message_counts(self, time_now_s: int): else: del self.actions[key] - def ratelimit( + async def ratelimit( self, - key: Hashable, + requester: Optional[Requester], + key: Optional[Hashable] = None, rate_hz: Optional[float] = None, burst_count: Optional[int] = None, update: bool = True, @@ -185,8 +186,16 @@ def ratelimit( ): """Checks if an action can be performed. If not, raises a LimitExceededError + Checks if the user has ratelimiting disabled in the database by looking + for null/zero values in the `ratelimit_override` table. (Non-zero + values aren't honoured, as they're specific to the event sending + ratelimiter, rather than all ratelimiters) + Args: - key: An arbitrary key used to classify an action + requester: The requester that is doing the action, if any. Used to check for + if the user has ratelimits disabled. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. rate_hz: The long term number of actions that can be performed in a second. Overrides the value set during instantiation if set. burst_count: How many actions that can be performed before being limited. @@ -201,7 +210,8 @@ def ratelimit( """ time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() - allowed, time_allowed = self.can_do_action( + allowed, time_allowed = await self.can_do_action( + requester, key, rate_hz=rate_hz, burst_count=burst_count, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d84e362070d9..71cb120ef76b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -870,6 +870,7 @@ def __init__(self, hs: "HomeServer"): # A rate limiter for incoming room key requests per origin. self._room_key_request_rate_limiter = Ratelimiter( + store=hs.get_datastore(), clock=self.clock, rate_hz=self.config.rc_key_requests.per_second, burst_count=self.config.rc_key_requests.burst_count, @@ -930,7 +931,9 @@ async def on_edu(self, edu_type: str, origin: str, content: dict): # the limit, drop them. if ( edu_type == EduTypes.RoomKeyRequest - and not self._room_key_request_rate_limiter.can_do_action(origin) + and not await self._room_key_request_rate_limiter.can_do_action( + None, origin + ) ): return diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index aade2c4a3ad4..fb899aa90d4d 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -49,7 +49,7 @@ def __init__(self, hs: "HomeServer"): # The rate_hz and burst_count are overridden on a per-user basis self.request_ratelimiter = Ratelimiter( - clock=self.clock, rate_hz=0, burst_count=0 + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 ) self._rc_message = self.hs.config.rc_message @@ -57,6 +57,7 @@ def __init__(self, hs: "HomeServer"): # by the presence of rate limits in the config if self.hs.config.rc_admin_redaction: self.admin_redaction_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_admin_redaction.per_second, burst_count=self.hs.config.rc_admin_redaction.burst_count, @@ -91,11 +92,6 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False): if app_service is not None: return # do not ratelimit app service senders - # Disable rate limiting of users belonging to any AS that is configured - # not to be rate limited in its registration file (rate_limited: true|false). - if requester.app_service and not requester.app_service.is_rate_limited(): - return - messages_per_second = self._rc_message.per_second burst_count = self._rc_message.burst_count @@ -113,11 +109,11 @@ async def ratelimit(self, requester, update=True, is_admin_redaction=False): if is_admin_redaction and self.admin_redaction_ratelimiter: # If we have separate config for admin redactions, use a separate # ratelimiter as to not have user_ids clash - self.admin_redaction_ratelimiter.ratelimit(user_id, update=update) + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) else: # Override rate and burst count per-user - self.request_ratelimiter.ratelimit( - user_id, + await self.request_ratelimiter.ratelimit( + requester, rate_hz=messages_per_second, burst_count=burst_count, update=update, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index d537ea813785..08e413bc98e0 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,7 @@ def __init__(self, hs: "HomeServer"): # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -248,6 +249,7 @@ def __init__(self, hs: "HomeServer"): # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_failed_attempts.per_second, burst_count=self.hs.config.rc_login_failed_attempts.burst_count, @@ -352,7 +354,7 @@ async def validate_user_via_ui_auth( requester_user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts - self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False) + await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False) # build a list of supported flows supported_ui_auth_types = await self._get_available_ui_auth_types( @@ -373,7 +375,9 @@ def get_new_session_data() -> JsonDict: ) except LoginError: # Update the ratelimiter to say we failed (`can_do_action` doesn't raise). - self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id) + await self._failed_uia_attempts_ratelimiter.can_do_action( + requester, + ) raise # find the completed login type @@ -982,8 +986,8 @@ async def validate_login( # We also apply account rate limiting using the 3PID as a key, as # otherwise using 3PID bypasses the ratelimiting based on user ID. if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - (medium, address), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, (medium, address), update=False ) # Check for login providers that support 3pid login types @@ -1016,8 +1020,8 @@ async def validate_login( # this code path, which is fine as then the per-user ratelimit # will kick in below. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - (medium, address) + await self._failed_login_attempts_ratelimiter.can_do_action( + None, (medium, address) ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) @@ -1039,8 +1043,8 @@ async def validate_login( # Check if we've hit the failed ratelimit (but don't update it) if ratelimit: - self._failed_login_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + await self._failed_login_attempts_ratelimiter.ratelimit( + None, qualified_user_id.lower(), update=False ) try: @@ -1051,8 +1055,8 @@ async def validate_login( # exception and masking the LoginError. The actual ratelimiting # should have happened above. if ratelimit: - self._failed_login_attempts_ratelimiter.can_do_action( - qualified_user_id.lower() + await self._failed_login_attempts_ratelimiter.can_do_action( + None, qualified_user_id.lower() ) raise diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb547743be9f..5ee48be6ffa7 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -81,6 +81,7 @@ def __init__(self, hs: "HomeServer"): ) self._ratelimiter = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.rc_key_requests.per_second, burst_count=hs.config.rc_key_requests.burst_count, @@ -191,8 +192,8 @@ async def send_device_message( if ( message_type == EduTypes.RoomKeyRequest and user_id != sender_user_id - and self._ratelimiter.can_do_action( - (sender_user_id, requester.device_id) + and await self._ratelimiter.can_do_action( + requester, (sender_user_id, requester.device_id) ) ): continue diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 598a66f74cf4..3ebee38ebe13 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1711,7 +1711,7 @@ async def on_invite_request( member_handler = self.hs.get_room_member_handler() # We don't rate limit based on room ID, as that should be done by # sending server. - member_handler.ratelimit_invite(None, event.state_key) + await member_handler.ratelimit_invite(None, None, event.state_key) # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 5f346f6d6d28..d89fa5fb305d 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -61,17 +61,19 @@ def __init__(self, hs): # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) self._3pid_validation_ratelimiter_address = Ratelimiter( + store=self.store, clock=hs.get_clock(), rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, ) - def ratelimit_request_token_requests( + async def ratelimit_request_token_requests( self, request: SynapseRequest, medium: str, @@ -85,8 +87,12 @@ def ratelimit_request_token_requests( address: The actual threepid ID, e.g. the phone number or email address """ - self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) - self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + await self._3pid_validation_ratelimiter_ip.ratelimit( + None, (medium, request.getClientIP()) + ) + await self._3pid_validation_ratelimiter_address.ratelimit( + None, (medium, address) + ) async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0fc2bf15d520..9701b76d0f91 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -204,7 +204,7 @@ async def register_user( Raises: SynapseError if there was a problem registering. """ - self.check_registration_ratelimit(address) + await self.check_registration_ratelimit(address) result = await self.spam_checker.check_registration_for_spam( threepid, @@ -583,7 +583,7 @@ def check_user_id_not_appservice_exclusive( errcode=Codes.EXCLUSIVE, ) - def check_registration_ratelimit(self, address: Optional[str]) -> None: + async def check_registration_ratelimit(self, address: Optional[str]) -> None: """A simple helper method to check whether the registration rate limit has been hit for a given IP address @@ -597,7 +597,7 @@ def check_registration_ratelimit(self, address: Optional[str]) -> None: if not address: return - self.ratelimiter.ratelimit(address) + await self.ratelimiter.ratelimit(None, address) async def register_with_store( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4d20ed835764..1cf12f3255bc 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -75,22 +75,26 @@ def __init__(self, hs: "HomeServer"): self.allow_per_room_profiles = self.config.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, ) self._join_rate_limiter_remote = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) self._invites_per_room_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, ) self._invites_per_user_limiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, @@ -159,15 +163,20 @@ async def _user_left_room(self, target: UserID, room_id: str) -> None: async def forget(self, user: UserID, room_id: str) -> None: raise NotImplementedError() - def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str): + async def ratelimit_invite( + self, + requester: Optional[Requester], + room_id: Optional[str], + invitee_user_id: str, + ): """Ratelimit invites by room and by target user. If room ID is missing then we just rate limit by target user. """ if room_id: - self._invites_per_room_limiter.ratelimit(room_id) + await self._invites_per_room_limiter.ratelimit(requester, room_id) - self._invites_per_user_limiter.ratelimit(invitee_user_id) + await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) async def _local_membership_update( self, @@ -237,7 +246,7 @@ async def _local_membership_update( ( allowed, time_allowed, - ) = self._join_rate_limiter_local.can_requester_do_action(requester) + ) = await self._join_rate_limiter_local.can_do_action(requester) if not allowed: raise LimitExceededError( @@ -421,9 +430,7 @@ async def update_membership_locked( if effective_membership_state == Membership.INVITE: target_id = target.to_string() if ratelimit: - # Don't ratelimit application services. - if not requester.app_service or requester.app_service.is_rate_limited(): - self.ratelimit_invite(room_id, target_id) + await self.ratelimit_invite(requester, room_id, target_id) # block any attempts to invite the server notices mxid if target_id == self._server_notices_mxid: @@ -534,7 +541,7 @@ async def update_membership_locked( ( allowed, time_allowed, - ) = self._join_rate_limiter_remote.can_requester_do_action( + ) = await self._join_rate_limiter_remote.can_do_action( requester, ) diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py index d005f3876717..73d747785420 100644 --- a/synapse/replication/http/register.py +++ b/synapse/replication/http/register.py @@ -77,7 +77,7 @@ async def _serialize_payload( async def _handle_request(self, request, user_id): content = parse_json_object_from_request(request) - self.registration_handler.check_registration_ratelimit(content["address"]) + await self.registration_handler.check_registration_ratelimit(content["address"]) await self.registration_handler.register_with_store( user_id=user_id, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index e4c352f572a2..3151e72d4f19 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -74,11 +74,13 @@ def __init__(self, hs: "HomeServer"): self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, ) self._account_ratelimiter = Ratelimiter( + store=hs.get_datastore(), clock=hs.get_clock(), rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, @@ -141,20 +143,22 @@ async def on_POST(self, request: SynapseRequest): appservice = self.auth.get_appservice_by_req(request) if appservice.is_rate_limited(): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit( + None, request.getClientIP() + ) result = await self._do_appservice_login(login_submission, appservice) elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_token_login(login_submission) else: - self._address_ratelimiter.ratelimit(request.getClientIP()) + await self._address_ratelimiter.ratelimit(None, request.getClientIP()) result = await self._do_other_login(login_submission) except KeyError: raise SynapseError(400, "Missing JSON keys.") @@ -258,7 +262,7 @@ async def _complete_login( # too often. This happens here rather than before as we don't # necessarily know the user before now. if ratelimit: - self._account_ratelimiter.ratelimit(user_id.lower()) + await self._account_ratelimiter.ratelimit(None, user_id.lower()) if create_non_existent_users: canonical_uid = await self.auth_handler.check_user_exists(user_id) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index c2ba790babde..411fb57c473d 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -103,7 +103,9 @@ async def on_POST(self, request): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) # The email will be sent to the stored address. # This avoids a potential account hijack by requesting a password reset to @@ -387,7 +389,9 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) if next_link: # Raise if the provided next_link value isn't valid @@ -468,7 +472,7 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 8f68d8dfc8fa..c212da0cb29c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,7 +126,9 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests(request, "email", email) + await self.identity_handler.ratelimit_request_token_requests( + request, "email", email + ) existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email @@ -208,7 +210,7 @@ async def on_POST(self, request): Codes.THREEPID_DENIED, ) - self.identity_handler.ratelimit_request_token_requests( + await self.identity_handler.ratelimit_request_token_requests( request, "msisdn", msisdn ) @@ -406,7 +408,7 @@ async def on_POST(self, request): client_addr = request.getClientIP() - self.ratelimiter.ratelimit(client_addr, update=False) + await self.ratelimiter.ratelimit(None, client_addr, update=False) kind = b"user" if b"kind" in request.args: diff --git a/synapse/server.py b/synapse/server.py index e85b9391faae..e42f7b1a1876 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -329,6 +329,7 @@ def get_distributor(self) -> Distributor: @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: return Ratelimiter( + store=self.get_datastore(), clock=self.get_clock(), rate_hz=self.config.rc_registration.per_second, burst_count=self.config.rc_registration.burst_count, diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483418192c4b..fa96ba07a590 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -5,38 +5,25 @@ from tests import unittest -class TestRatelimiter(unittest.TestCase): +class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_via_can_do_action(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0) - self.assertTrue(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5) - self.assertFalse(allowed) - self.assertEquals(10.0, time_allowed) - - allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10) - self.assertTrue(allowed) - self.assertEquals(20.0, time_allowed) - - def test_allowed_user_via_can_requester_do_action(self): - user_requester = create_requester("@user:example.com") - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - user_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, key="test_id", _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -51,21 +38,23 @@ def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertFalse(allowed) self.assertEquals(10.0, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(20.0, time_allowed) @@ -80,73 +69,89 @@ def test_allowed_appservice_via_can_requester_do_action(self): ) as_requester = create_requester("@user:example.com", app_service=appservice) - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=0 + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=0) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=5 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=5) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) - allowed, time_allowed = limiter.can_requester_do_action( - as_requester, _time_now_s=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(as_requester, _time_now_s=10) ) self.assertTrue(allowed) self.assertEquals(-1, time_allowed) def test_allowed_via_ratelimit(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=0) + self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0)) # Should raise with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key="test_id", _time_now_s=5) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=5) + ) self.assertEqual(context.exception.retry_after_ms, 5000) # Shouldn't raise - limiter.ratelimit(key="test_id", _time_now_s=10) + self.get_success_or_raise( + limiter.ratelimit(None, key="test_id", _time_now_s=10) + ) def test_allowed_via_can_do_action_and_overriding_parameters(self): """Test that we can override options of can_do_action that would otherwise fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=0, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=0, + ) ) self.assertTrue(allowed) self.assertEqual(10.0, time_allowed) # Second attempt, 1s later, will fail - allowed, time_allowed = limiter.can_do_action( - ("test_id",), - _time_now_s=1, + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action( + None, + ("test_id",), + _time_now_s=1, + ) ) self.assertFalse(allowed) self.assertEqual(10.0, time_allowed) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, rate_hz=10.0 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0) ) self.assertTrue(allowed) self.assertEqual(1.1, time_allowed) # Similarly if we allow a burst of 10 actions - allowed, time_allowed = limiter.can_do_action( - ("test_id",), _time_now_s=1, burst_count=10 + allowed, time_allowed = self.get_success_or_raise( + limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10) ) self.assertTrue(allowed) self.assertEqual(1.0, time_allowed) @@ -156,29 +161,72 @@ def test_allowed_via_ratelimit_and_overriding_parameters(self): fail an action """ # Create a Ratelimiter with a very low allowed rate_hz and burst_count - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) # First attempt should be allowed - limiter.ratelimit(key=("test_id",), _time_now_s=0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=0) + ) # Second attempt, 1s later, will fail with self.assertRaises(LimitExceededError) as context: - limiter.ratelimit(key=("test_id",), _time_now_s=1) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1) + ) self.assertEqual(context.exception.retry_after_ms, 9000) # But, if we allow 10 actions/sec for this request, we should be allowed # to continue. - limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0) + ) # Similarly if we allow a burst of 10 actions - limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10) + self.get_success_or_raise( + limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10) + ) def test_pruning(self): - limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1) - limiter.can_do_action(key="test_id_1", _time_now_s=0) + limiter = Ratelimiter( + store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1 + ) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_1", _time_now_s=0) + ) self.assertIn("test_id_1", limiter.actions) - limiter.can_do_action(key="test_id_2", _time_now_s=10) + self.get_success_or_raise( + limiter.can_do_action(None, key="test_id_2", _time_now_s=10) + ) self.assertNotIn("test_id_1", limiter.actions) + + def test_db_user_override(self): + """Test that users that have ratelimiting disabled in the DB aren't + ratelimited. + """ + store = self.hs.get_datastore() + + user_id = "@user:test" + requester = create_requester(user_id) + + self.get_success( + store.db_pool.simple_insert( + table="ratelimit_override", + values={ + "user_id": user_id, + "messages_per_second": None, + "burst_count": None, + }, + desc="test_db_user_override", + ) + ) + + limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1) + + # Shouldn't raise + for _ in range(20): + self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))