Skip to content

Commit ab2df17

Browse files
committed
Add tests for v2 API
Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>
1 parent f3fbd24 commit ab2df17

File tree

3 files changed

+379
-53
lines changed

3 files changed

+379
-53
lines changed

vulnerabilities/api.py

Lines changed: 81 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,7 @@ class AliasViewSet(VulnerabilityViewSet):
691691

692692
filterset_class = AliasFilterSet
693693

694+
694695
class WeaknessV2Serializer(serializers.ModelSerializer):
695696
cwe_id = serializers.CharField()
696697
name = serializers.CharField()
@@ -700,16 +701,6 @@ class Meta:
700701
model = Weakness
701702
fields = ["cwe_id", "name", "description"]
702703

703-
class VulnerabilityFilter(filters.FilterSet):
704-
vulnerability_id = filters.CharFilter(field_name='vulnerability_id', lookup_expr='exact')
705-
vulnerability_id__in = filters.BaseInFilter(field_name='vulnerability_id', lookup_expr='in')
706-
alias = filters.CharFilter(field_name='aliases__alias', lookup_expr='exact')
707-
alias__in = filters.BaseInFilter(field_name='aliases__alias', lookup_expr='in')
708-
709-
class Meta:
710-
model = Vulnerability
711-
fields = ['vulnerability_id', 'vulnerability_id__in', 'alias', 'alias__in']
712-
713704

714705
class VulnerabilityReferenceV2Serializer(serializers.ModelSerializer):
715706
url = serializers.CharField()
@@ -720,10 +711,11 @@ class Meta:
720711
model = VulnerabilityReference
721712
fields = ["url", "reference_type", "reference_id"]
722713

714+
723715
class VulnerabilityV2Serializer(BaseResourceSerializer):
724716
aliases = serializers.SerializerMethodField()
725717
weaknesses = WeaknessV2Serializer(many=True)
726-
references = VulnerabilityReferenceV2Serializer(many=True, source='vulnerabilityreference_set')
718+
references = VulnerabilityReferenceV2Serializer(many=True, source="vulnerabilityreference_set")
727719
severities = VulnerabilitySeveritySerializer(many=True)
728720

729721
class Meta:
@@ -744,52 +736,74 @@ def get_severities(self, obj):
744736
return obj.severities
745737

746738

739+
class VulnerabilityListSerializer(serializers.ModelSerializer):
740+
url = serializers.SerializerMethodField()
741+
742+
class Meta:
743+
model = Vulnerability
744+
fields = ["vulnerability_id", "url"]
745+
746+
def get_url(self, obj):
747+
request = self.context.get("request")
748+
return reverse(
749+
"vulnerability-v2-detail",
750+
kwargs={"vulnerability_id": obj.vulnerability_id},
751+
request=request,
752+
)
753+
754+
747755
class VulnerabilityV2ViewSet(viewsets.ReadOnlyModelViewSet):
748756
queryset = Vulnerability.objects.all()
749757
serializer_class = VulnerabilityV2Serializer
758+
lookup_field = "vulnerability_id"
759+
760+
def get_queryset(self):
761+
queryset = super().get_queryset()
762+
vulnerability_ids = self.request.query_params.getlist("vulnerability_id")
763+
aliases = self.request.query_params.getlist("alias")
764+
765+
if vulnerability_ids:
766+
queryset = queryset.filter(vulnerability_id__in=vulnerability_ids)
767+
768+
if aliases:
769+
queryset = queryset.filter(aliases__alias__in=aliases).distinct()
770+
771+
return queryset
772+
773+
def get_serializer_class(self):
774+
if self.action == "list":
775+
return VulnerabilityListSerializer
776+
return super().get_serializer_class()
750777

751778
def list(self, request, *args, **kwargs):
752779
queryset = self.get_queryset()
753-
# Apply pagination
780+
vulnerability_ids = request.query_params.getlist("vulnerability_id")
781+
782+
# If exactly one vulnerability_id is provided, return the serialized data
783+
if len(vulnerability_ids) == 1:
784+
try:
785+
vulnerability = queryset.get(vulnerability_id=vulnerability_ids[0])
786+
serializer = self.get_serializer(vulnerability)
787+
return Response(serializer.data)
788+
except Vulnerability.DoesNotExist:
789+
return Response({"detail": "Not found."}, status=404)
790+
791+
# Otherwise, return a dictionary of vulnerabilities keyed by vulnerability_id
754792
page = self.paginate_queryset(queryset)
755793
if page is not None:
756794
serializer = self.get_serializer(page, many=True)
757795
data = serializer.data
758-
vulnerabilities = {item['vulnerability_id']: item for item in data}
759-
# Use 'self.get_paginated_response' to include pagination data
760-
return self.get_paginated_response({'vulnerabilities': vulnerabilities})
796+
vulnerabilities = {item["vulnerability_id"]: item for item in data}
797+
return self.get_paginated_response({"vulnerabilities": vulnerabilities})
761798

762-
# If pagination is not applied
763799
serializer = self.get_serializer(queryset, many=True)
764800
data = serializer.data
765-
vulnerabilities = {item['vulnerability_id']: item for item in data}
766-
return Response({'vulnerabilities': vulnerabilities})
767-
768-
769-
class PackageFilter(filters.FilterSet):
770-
purl = filters.CharFilter(field_name='package_url', lookup_expr='exact')
771-
purl__in = filters.BaseInFilter(field_name='package_url', lookup_expr='in')
772-
affected_by_vulnerability = filters.CharFilter(
773-
field_name='affected_by_vulnerabilities__vulnerability_id',
774-
lookup_expr='exact'
775-
)
776-
fixing_vulnerability = filters.CharFilter(
777-
field_name='fixing_vulnerabilities__vulnerability_id',
778-
lookup_expr='exact'
779-
)
780-
781-
class Meta:
782-
model = Package
783-
fields = [
784-
'purl',
785-
'purl__in',
786-
'affected_by_vulnerability',
787-
'fixing_vulnerability',
788-
]
801+
vulnerabilities = {item["vulnerability_id"]: item for item in data}
802+
return Response({"vulnerabilities": vulnerabilities})
789803

790804

791805
class PackageV2Serializer(serializers.ModelSerializer):
792-
purl = serializers.CharField(source='package_url')
806+
purl = serializers.CharField(source="package_url")
793807
affected_by_vulnerabilities = serializers.SerializerMethodField()
794808
fixing_vulnerabilities = serializers.SerializerMethodField()
795809
next_non_vulnerable_version = serializers.CharField(read_only=True)
@@ -798,11 +812,11 @@ class PackageV2Serializer(serializers.ModelSerializer):
798812
class Meta:
799813
model = Package
800814
fields = [
801-
'purl',
802-
'affected_by_vulnerabilities',
803-
'fixing_vulnerabilities',
804-
'next_non_vulnerable_version',
805-
'latest_non_vulnerable_version',
815+
"purl",
816+
"affected_by_vulnerabilities",
817+
"fixing_vulnerabilities",
818+
"next_non_vulnerable_version",
819+
"latest_non_vulnerable_version",
806820
]
807821

808822
def get_affected_by_vulnerabilities(self, obj):
@@ -815,19 +829,36 @@ def get_fixing_vulnerabilities(self, obj):
815829
class PackageV2ViewSet(viewsets.ReadOnlyModelViewSet):
816830
queryset = Package.objects.all()
817831
serializer_class = PackageV2Serializer
818-
filterset_class = PackageFilter
832+
833+
def get_queryset(self):
834+
queryset = super().get_queryset()
835+
package_purls = self.request.query_params.getlist("purl")
836+
affected_by_vulnerability = self.request.query_params.get("affected_by_vulnerability")
837+
fixing_vulnerability = self.request.query_params.get("fixing_vulnerability")
838+
839+
if package_purls:
840+
queryset = queryset.filter(package_url__in=package_purls)
841+
if affected_by_vulnerability:
842+
queryset = queryset.filter(
843+
affected_by_vulnerabilities__vulnerability_id=affected_by_vulnerability
844+
)
845+
if fixing_vulnerability:
846+
queryset = queryset.filter(
847+
fixing_vulnerabilities__vulnerability_id=fixing_vulnerability
848+
)
849+
return queryset.with_is_vulnerable()
819850

820851
def list(self, request, *args, **kwargs):
821-
queryset = self.get_queryset().with_is_vulnerable()
852+
queryset = self.get_queryset()
822853
# Apply pagination
823854
page = self.paginate_queryset(queryset)
824855
if page is not None:
825856
serializer = self.get_serializer(page, many=True)
826857
data = serializer.data
827858
# Use 'self.get_paginated_response' to include pagination data
828-
return self.get_paginated_response({'purls': data})
859+
return self.get_paginated_response({"purls": data})
829860

830861
# If pagination is not applied
831862
serializer = self.get_serializer(queryset, many=True)
832863
data = serializer.data
833-
return Response({'purls': data})
864+
return Response({"purls": data})

0 commit comments

Comments
 (0)