Skip to content

Throttle API requests based on user permissions #1909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions vulnerabilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from vulnerabilities.models import get_purl_query_lookups
from vulnerabilities.severity_systems import EPSS
from vulnerabilities.severity_systems import SCORING_SYSTEMS
from vulnerabilities.throttling import StaffUserRateThrottle
from vulnerabilities.throttling import PermissionBasedUserRateThrottle
from vulnerabilities.utils import get_severity_range


Expand Down Expand Up @@ -471,7 +471,7 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = PackageSerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
return super().get_queryset().with_is_vulnerable()
Expand Down Expand Up @@ -688,7 +688,7 @@ def get_queryset(self):
serializer_class = VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = VulnerabilityFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]


class CPEFilterSet(filters.FilterSet):
Expand Down
10 changes: 5 additions & 5 deletions vulnerabilities/api_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from vulnerabilities.models import VulnerabilitySeverity
from vulnerabilities.models import Weakness
from vulnerabilities.models import get_purl_query_lookups
from vulnerabilities.throttling import StaffUserRateThrottle
from vulnerabilities.throttling import PermissionBasedUserRateThrottle


class SerializerExcludeFieldsMixin:
Expand Down Expand Up @@ -259,7 +259,7 @@ class V2PackageViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = "purl"
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = V2PackageFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle]

def get_queryset(self):
return super().get_queryset().with_is_vulnerable().prefetch_related("vulnerabilities")
Expand Down Expand Up @@ -345,7 +345,7 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet):
lookup_field = "vulnerability_id"
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = V2VulnerabilityFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle]

def get_queryset(self):
"""
Expand Down Expand Up @@ -381,7 +381,7 @@ class CPEViewSet(viewsets.ReadOnlyModelViewSet):
).distinct()
serializer_class = V2VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle]
filterset_class = CPEFilterSet

@action(detail=False, methods=["post"])
Expand Down Expand Up @@ -420,4 +420,4 @@ class AliasViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = V2VulnerabilitySerializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = AliasFilterSet
throttle_classes = [StaffUserRateThrottle, AnonRateThrottle]
throttle_classes = [PermissionBasedUserRateThrottle, AnonRateThrottle]
6 changes: 6 additions & 0 deletions vulnerabilities/api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from rest_framework.permissions import BasePermission
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.throttling import AnonRateThrottle

from vulnerabilities.models import CodeFix
from vulnerabilities.models import Package
Expand All @@ -32,6 +33,7 @@
from vulnerabilities.models import VulnerabilityReference
from vulnerabilities.models import VulnerabilitySeverity
from vulnerabilities.models import Weakness
from vulnerabilities.throttling import PermissionBasedUserRateThrottle


class WeaknessV2Serializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -134,6 +136,7 @@ class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Vulnerability.objects.all()
serializer_class = VulnerabilityV2Serializer
lookup_field = "vulnerability_id"
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
queryset = super().get_queryset()
Expand Down Expand Up @@ -272,6 +275,7 @@ class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
serializer_class = PackageV2Serializer
filter_backends = (filters.DjangoFilterBackend,)
filterset_class = PackageV2FilterSet
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
queryset = super().get_queryset()
Expand Down Expand Up @@ -599,6 +603,7 @@ class CodeFixViewSet(viewsets.ReadOnlyModelViewSet):

queryset = CodeFix.objects.all()
serializer_class = CodeFixSerializer
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_queryset(self):
"""
Expand Down Expand Up @@ -741,6 +746,7 @@ class PipelineScheduleV2ViewSet(CreateListRetrieveUpdateViewSet):
serializer_class = PipelineScheduleAPISerializer
lookup_field = "pipeline_id"
lookup_value_regex = r"[\w.]+"
throttle_classes = [AnonRateThrottle, PermissionBasedUserRateThrottle]

def get_serializer_class(self):
if self.action == "create":
Expand Down
24 changes: 24 additions & 0 deletions vulnerabilities/migrations/0093_alter_apiuser_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 4.2.22 on 2025-06-13 12:44

from django.db import migrations


class Migration(migrations.Migration):

dependencies = [
("vulnerabilities", "0092_pipelineschedule_pipelinerun"),
]

operations = [
migrations.AlterModelOptions(
name="apiuser",
options={
"permissions": [
("throttle_unrestricted", "Can make api requests without throttling limits"),
("throttle_18000_hour", "Can make 18000 api requests per hour"),
("throttle_14400_hour", "Can make 14400 api requests per hour"),
("throttle_3600_hour", "Can make 3600 api requests per hour"),
]
},
),
]
11 changes: 8 additions & 3 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from cwe2.mappings import xml_database_path
from cwe2.weakness import Weakness as DBWeakness
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.contrib.auth.models import UserManager
from django.core import exceptions
from django.core.exceptions import ValidationError
Expand Down Expand Up @@ -1452,14 +1453,18 @@ def _validate_username(self, email):


class ApiUser(UserModel):
"""
A User proxy model to facilitate simplified admin API user creation.
"""
"""A User proxy model to facilitate simplified admin API user creation."""

objects = ApiUserManager()

class Meta:
proxy = True
permissions = [
("throttle_unrestricted", "Can make api requests without throttling limits"),
("throttle_18000_hour", "Can make 18000 api requests per hour"),
("throttle_14400_hour", "Can make 14400 api requests per hour"),
("throttle_3600_hour", "Can make 3600 api requests per hour"),
]


class ChangeLog(models.Model):
Expand Down
25 changes: 11 additions & 14 deletions vulnerabilities/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
from urllib.parse import quote

from django.core.cache import cache
from django.test import TestCase
from django.test import TransactionTestCase
from django.test.client import RequestFactory
Expand Down Expand Up @@ -452,10 +453,8 @@ def add_aliases(vuln, aliases):

class APIPerformanceTest(TestCase):
def setUp(self):
self.user = ApiUser.objects.create_api_user(username="e@mail.com")
self.auth = f"Token {self.user.auth_token.key}"
cache.clear()
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)

# This setup creates the following data:
# vulnerabilities: vul1, vul2, vul3
Expand Down Expand Up @@ -503,7 +502,7 @@ def setUp(self):
set_as_fixing(package=self.pkg_2_13_2, vulnerability=self.vul1)

def test_api_packages_all_num_queries(self):
with self.assertNumQueries(4):
with self.assertNumQueries(3):
# There are 4 queries:
# 1. SAVEPOINT
# 2. Authenticating user
Expand All @@ -519,22 +518,22 @@ def test_api_packages_all_num_queries(self):
]

def test_api_packages_single_num_queries(self):
with self.assertNumQueries(8):
with self.assertNumQueries(7):
self.csrf_client.get(f"/api/packages/{self.pkg_2_14_0_rc1.id}", format="json")

def test_api_packages_single_with_purl_in_query_num_queries(self):
with self.assertNumQueries(9):
with self.assertNumQueries(8):
self.csrf_client.get(f"/api/packages/?purl={self.pkg_2_14_0_rc1.purl}", format="json")

def test_api_packages_single_with_purl_no_version_in_query_num_queries(self):
with self.assertNumQueries(64):
with self.assertNumQueries(63):
self.csrf_client.get(
f"/api/packages/?purl=pkg:maven/com.fasterxml.jackson.core/jackson-databind",
format="json",
)

def test_api_packages_bulk_search(self):
with self.assertNumQueries(45):
with self.assertNumQueries(44):
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
purls = [p.purl for p in packages]

Expand All @@ -547,7 +546,7 @@ def test_api_packages_bulk_search(self):
).json()

def test_api_packages_with_lookup(self):
with self.assertNumQueries(14):
with self.assertNumQueries(13):
data = {"purl": self.pkg_2_12_6.purl}

resp = self.csrf_client.post(
Expand All @@ -557,7 +556,7 @@ def test_api_packages_with_lookup(self):
).json()

def test_api_packages_bulk_lookup(self):
with self.assertNumQueries(45):
with self.assertNumQueries(44):
packages = [self.pkg_2_12_6, self.pkg_2_12_6_1, self.pkg_2_13_1]
purls = [p.purl for p in packages]

Expand All @@ -572,10 +571,8 @@ def test_api_packages_bulk_lookup(self):

class APITestCasePackage(TestCase):
def setUp(self):
self.user = ApiUser.objects.create_api_user(username="e@mail.com")
self.auth = f"Token {self.user.auth_token.key}"
cache.clear()
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.csrf_client.credentials(HTTP_AUTHORIZATION=self.auth)

# This setup creates the following data:
# vulnerabilities: vul1, vul2, vul3
Expand Down Expand Up @@ -766,7 +763,7 @@ def test_api_with_wrong_namespace_filter(self):
self.assertEqual(response["count"], 0)

def test_api_with_all_vulnerable_packages(self):
with self.assertNumQueries(4):
with self.assertNumQueries(3):
# There are 4 queries:
# 1. SAVEPOINT
# 2. Authenticating user
Expand Down
Loading