Skip to content

Block too many login attempts #5811

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,35 @@

from fastapi import Request
from fastapi.security import APIKeyHeader
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
from starlette.authentication import AuthCredentials, BaseUser

from argilla_server.constants import API_KEY_HEADER_NAME
from argilla_server.contexts import accounts
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.authentication.db.login_backend import LoginAuthenticationBackend


class APIKeyAuthenticationBackend(AuthenticationBackend):
class APIKeyAuthenticationBackend(LoginAuthenticationBackend):
"""Authentication backend for API Key authentication"""

scheme = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)

async def authenticate(self, request: Request) -> Optional[Tuple[AuthCredentials, BaseUser]]:
"""Authenticate the user using the API Key header"""
api_key: str = await self.scheme(request)
client_ip = request.client.host
if not api_key:
return None
is_locked = self.check_lockout(client_ip)
if is_locked:
return None

db = request.state.db
user = await accounts.get_user_by_api_key(db, api_key=api_key)
if not user:
self.increase_lockout(client_ip)
return None
self.clear_lockout(client_ip)

return AuthCredentials(), UserInfo(
username=user.username, name=user.first_name, role=user.role, identity=str(user.id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,41 @@

from fastapi import Request
from fastapi.security import HTTPBearer
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
from starlette.authentication import AuthCredentials, BaseUser

from argilla_server.contexts import accounts
from argilla_server.security.authentication.jwt import JWT
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.authentication.db.login_backend import LoginAuthenticationBackend


class BearerTokenAuthenticationBackend(AuthenticationBackend):
class BearerTokenAuthenticationBackend(LoginAuthenticationBackend):
"""Authenticate the user using the username and password Bearer header"""

scheme = HTTPBearer(auto_error=False)

async def authenticate(self, request: Request) -> typing.Optional[typing.Tuple[AuthCredentials, BaseUser]]:
"""Authenticate the user using the username and password Bearer header"""
credentials = await self.scheme(request)
client_ip = request.client.host
if not credentials:
return None

token = credentials.credentials
username = JWT.decode(token).get("username")
is_locked = self.check_lockout(username)
is_locked_ip = self.check_lockout(client_ip)
if is_locked or is_locked_ip:
return None

db = request.state.db
user = await accounts.get_user_by_username(db, username)
if not user:
self.increase_lockout(username)
self.increase_lockout(client_ip)
return None

self.clear_lockout(username)
self.clear_lockout(client_ip)
return AuthCredentials(), UserInfo(
username=user.username, name=user.first_name, role=user.role, identity=str(user.id)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from starlette.authentication import AuthenticationBackend
from argilla_server.jobs.queues import REDIS_CONNECTION

MAX_ATTEMPTS = 5
LOCKOUT_TIME = 300 # 5 minutes


class LoginAuthenticationBackend(AuthenticationBackend):
"""
Authentication backend which locks the user out after a certain amount
of wrong attempts to login
"""

def __init__(self):
self.redis = REDIS_CONNECTION

async def check_lockout(self, credential_key: str) -> bool:
"""Check if credential key (username or API key) is locked out"""
key = f"failed_auth:{credential_key}"
attempts = await self.redis.get(key)
attempts = int(attempts or 0)
if attempts >= MAX_ATTEMPTS:
return False
return True

async def increase_lockout(self, credential_key: str) -> None:
"""Increment failed attempts"""
key = f"failed_auth:{credential_key}"
await self.redis.incr(key)

# Ensure expiration is set after first failure
ttl = await self.redis.ttl(key)
if ttl == -1: # Key exists but no expiration set
await self.redis.expire(key, LOCKOUT_TIME)

async def clear_lockout(self, credential_key: str) -> None:
"""Reset failed attempts on successful authentication"""
key = f"failed_auth:{credential_key}"
await self.redis.delete(key)