Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Respond with 401 for requests with bad credentials #4126

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions api/api/docs/audio_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework.exceptions import (
AuthenticationFailed,
NotAuthenticated,
NotFound,
ValidationError,
Expand Down Expand Up @@ -75,7 +76,10 @@

By using this endpoint, you can obtain info about content providers such
as {fields_to_md(ProviderSerializer.Meta.fields)}.""",
res={200: (ProviderSerializer(many=True), audio_stats_200_example)},
res={
200: (ProviderSerializer(many=True), audio_stats_200_example),
401: (AuthenticationFailed, None),
},
eg=[audio_stats_curl],
)

Expand All @@ -87,6 +91,7 @@
{fields_to_md(AudioSerializer.Meta.fields)}""",
res={
200: (AudioSerializer, audio_detail_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_detail_404_example),
},
eg=[audio_detail_curl],
Expand All @@ -100,6 +105,7 @@
{fields_to_md(AudioSerializer.Meta.fields)}.""",
res={
200: (AudioSerializer(many=True), audio_related_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_related_404_example),
},
eg=[audio_related_curl],
Expand All @@ -109,18 +115,23 @@
res={
201: (AudioReportRequestSerializer, audio_complain_201_example),
400: (ValidationError, None),
401: (AuthenticationFailed, None),
},
eg=[audio_complain_curl],
)

thumbnail = extend_schema(
parameters=[MediaThumbnailRequestSerializer],
responses={200: OpenApiResponse(description="Thumbnail image")},
responses={
200: OpenApiResponse(description="Thumbnail image"),
401: AuthenticationFailed,
},
)

waveform = custom_extend_schema(
res={
200: (AudioWaveformSerializer, audio_waveform_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, audio_waveform_404_example),
},
eg=[audio_waveform_curl],
Expand Down
82 changes: 75 additions & 7 deletions api/api/docs/base_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

from django.conf import settings
from rest_framework.exceptions import (
APIException,
NotFound,
ValidationError,
)

from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.utils import (
OpenApiExample,
Expand All @@ -31,6 +34,77 @@ def fields_to_md(field_names):
return f"{all_but_last} and `{last}`"


class APIExceptionOpenApiSerializerExtension(OpenApiSerializerExtension):
target_class = APIException
match_subclasses = True

@classmethod
def _get_detail(cls, target):
return getattr(target, "detail", target.default_detail)

def get_name(self, *args):
cls = self.target if isinstance(self.target, type) else self.target.__class__
return cls.__name__

def map_serializer(self, *args):
cls = self.target if isinstance(self.target, type) else self.target.__class__

detail_string = {
"type": "string",
"description": "A description of what went wrong.",
}

if cls == ValidationError or issubclass(cls, ValidationError):
return {
"title": "ValidationError",
"type": "object",
"properties": {
"detail": {
"oneOf": [
detail_string,
{
"type": "object",
"additionalProperties": True,
},
]
}
},
}

return {
"title": cls.__name__,
"type": "object",
"properties": {"detail": detail_string},
}

@classmethod
def exception_example(cls, exception):
if exception == ValidationError:
return {"detail": {"<request parameter>": "<error details>"}}

return {"detail": cls._get_detail(exception)}


def get_examples(code, serializer, example):
if (
not example
and isinstance(serializer, type)
and issubclass(serializer, APIException)
):
example = APIExceptionOpenApiSerializerExtension.exception_example(serializer)
elif example:
example = example["application/json"]
else:
return []

return [
OpenApiExample(
http_responses[code],
value=example,
)
]


def custom_extend_schema(**kwargs):
extend_args = {}

Expand All @@ -51,13 +125,7 @@ def custom_extend_schema(**kwargs):
code: OpenApiResponse(
serializer,
description=http_responses[code],
examples=[
OpenApiExample(
http_responses[code], value=example["application/json"]
)
]
if example
else [],
examples=get_examples(code, serializer, example),
)
for code, (serializer, example) in responses.items()
}
Expand Down
26 changes: 22 additions & 4 deletions api/api/docs/image_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from rest_framework.exceptions import (
AuthenticationFailed,
NotAuthenticated,
NotFound,
ValidationError,
Expand Down Expand Up @@ -78,7 +79,10 @@

By using this endpoint, you can obtain info about content providers such
as {fields_to_md(ProviderSerializer.Meta.fields)}.""",
res={200: (ProviderSerializer(many=True), image_stats_200_example)},
res={
200: (ProviderSerializer(many=True), image_stats_200_example),
401: (AuthenticationFailed, None),
},
eg=[image_stats_curl],
)

Expand All @@ -90,6 +94,7 @@
{fields_to_md(ImageSerializer.Meta.fields)}""",
res={
200: (ImageSerializer, image_detail_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_detail_404_example),
},
eg=[image_detail_curl],
Expand All @@ -103,6 +108,7 @@
{fields_to_md(ImageSerializer.Meta.fields)}.""",
res={
200: (ImageSerializer, image_related_200_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_related_404_example),
},
eg=[image_related_curl],
Expand All @@ -111,24 +117,36 @@
report = custom_extend_schema(
res={
201: (ImageReportRequestSerializer, image_complain_201_example),
401: (AuthenticationFailed, None),
400: (ValidationError, None),
},
eg=[image_complain_curl],
)

thumbnail = extend_schema(
parameters=[MediaThumbnailRequestSerializer],
responses={200: OpenApiResponse(description="Thumbnail image"), 404: NotFound},
responses={
200: OpenApiResponse(description="Thumbnail image"),
404: NotFound,
401: AuthenticationFailed,
},
)

oembed = custom_extend_schema(
params=OembedRequestSerializer,
res={
200: (OembedSerializer, image_oembed_200_example),
404: (NotFound, image_oembed_404_example),
400: (ValidationError, image_oembed_400_example),
401: (AuthenticationFailed, None),
404: (NotFound, image_oembed_404_example),
},
eg=[image_oembed_curl],
)

watermark = extend_schema(deprecated=True, responses={404: NotFound})
watermark = extend_schema(
deprecated=True,
responses={
401: AuthenticationFailed,
404: NotFound,
},
)
5 changes: 2 additions & 3 deletions api/api/docs/oauth2_docs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from rest_framework.exceptions import (
APIException,
NotAuthenticated,
PermissionDenied,
ValidationError,
)

from api.docs.base_docs import custom_extend_schema
from api.examples import (
auth_key_info_200_example,
auth_key_info_403_example,
auth_key_info_curl,
auth_register_201_example,
auth_register_curl,
Expand All @@ -30,6 +28,7 @@
res={
201: (OAuth2ApplicationSerializer, auth_register_201_example),
400: (ValidationError, None),
401: ({"type": "object", "properties": {"error": {"type": "string"}}}, None),
429: (
APIException("Request was throttled. Expected available in 1 second.", 429),
None,
Expand All @@ -42,7 +41,7 @@
operation_id="key_info",
res={
200: (OAuth2KeyInfoSerializer, auth_key_info_200_example),
403: (PermissionDenied, auth_key_info_403_example),
401: (NotAuthenticated, None),
429: (
APIException("Request was throttled. Expected available in 1 second.", 429),
None,
Expand Down
1 change: 0 additions & 1 deletion api/api/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
)
from api.examples.oauth2_responses import (
auth_key_info_200_example,
auth_key_info_403_example,
auth_register_201_example,
auth_token_200_example,
)
Expand Down
2 changes: 0 additions & 2 deletions api/api/examples/oauth2_responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,3 @@
"rate_limit_model": "enhanced",
}
}

auth_key_info_403_example = {"application/json": "Forbidden"}
23 changes: 12 additions & 11 deletions api/api/views/oauth2_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from django.core.cache import cache
from django.core.mail import send_mail
from django.db import DataError
from rest_framework.exceptions import APIException, PermissionDenied
from rest_framework.exceptions import APIException
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.views import APIView

from drf_spectacular.utils import extend_schema
from oauth2_provider.contrib.rest_framework.permissions import TokenHasScope
from oauth2_provider.generators import generate_client_secret
from oauth2_provider.views import TokenView as BaseTokenView
from redis.exceptions import ConnectionError
Expand All @@ -40,6 +41,8 @@ class InvalidCredentials(APIException):
@extend_schema(tags=["auth"])
class Register(APIView):
throttle_classes = (TenPerDay,)
# Registration implicitly does not require authentication
authentication_classes = ()

@register
def post(self, request, format=None):
Expand Down Expand Up @@ -150,6 +153,10 @@ def get(self, request, code, format=None):

@extend_schema(tags=["auth"])
class TokenView(APIView, BaseTokenView):
# Token view is pre-authentication
authentication_classes = ()
permission_classes = ()

@token
def post(self, request):
"""
Expand Down Expand Up @@ -178,6 +185,8 @@ def post(self, request):
@extend_schema(tags=["auth"])
class CheckRates(APIView):
throttle_classes = (OnePerSecond,)
permission_classes = (TokenHasScope,)
required_scopes = ("read",)

@key_info
def get(self, request: Request, format=None):
Expand All @@ -187,21 +196,13 @@ def get(self, request: Request, format=None):
You can use this endpoint to get information about your API key such as
`requests_this_minute`, `requests_today`, and `rate_limit_model`.

> ℹ️ **NOTE:** If you get a 403 Forbidden response, it means your access
> token has expired.
> ℹ️ **NOTE:** If you get a 401 Unauthorized, it means your token is invalid
> (malformed, non-existent, or expired).
"""

# TODO: Replace 403 responses with DRF `authentication_classes`.
if not request.auth or not hasattr(request.auth, "application"):
raise PermissionDenied("Forbidden", 403)

application: ThrottledApplication = request.auth.application

client_id = application.client_id

if not client_id:
raise PermissionDenied("Forbidden", 403)

throttle_type = application.rate_limit_model
throttle_key = "throttle_{scope}_{client_id}"
if throttle_type == "standard":
Expand Down
27 changes: 27 additions & 0 deletions api/conf/oauth2_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from rest_framework.exceptions import AuthenticationFailed

from drf_spectacular.authentication import TokenScheme
from oauth2_provider.contrib.rest_framework import (
OAuth2Authentication as BaseOAuth2Authentication,
)


class OAuth2Authentication(BaseOAuth2Authentication):
# Required by schema extension
keyword = "Bearer"

def authenticate(self, request):
result = super().authenticate(request)
if getattr(request, "oauth2_error", None):
# oauth2_error is only defined on requests that had errors
# it will be undefined or empty for anonymous requests and
# requests with valid credentials
# `request` is mutated by `super().authenticate`
raise AuthenticationFailed()

return result


class OAuth2OpenApiAuthenticationExtension(TokenScheme):
target_class = "conf.oauth2_extensions.OAuth2Authentication"
name = "Openverse API Token"
4 changes: 1 addition & 3 deletions api/conf/settings/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@
)

REST_FRAMEWORK = {
"DEFAULT_AUTHENTICATION_CLASSES": (
"oauth2_provider.contrib.rest_framework.OAuth2Authentication",
),
"DEFAULT_AUTHENTICATION_CLASSES": ("conf.oauth2_extensions.OAuth2Authentication",),
"DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.URLPathVersioning",
"DEFAULT_RENDERER_CLASSES": (
"rest_framework.renderers.JSONRenderer",
Expand Down
Loading