Skip to content

Commit 7f27fc5

Browse files
committed
Add tests for improve_runner
Signed-off-by: Hritik Vijay <hritikxx8@gmail.com>
1 parent fdd17a7 commit 7f27fc5

File tree

3 files changed

+183
-8
lines changed

3 files changed

+183
-8
lines changed

vulnerabilities/improve_runner.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def run(self) -> None:
5353
@transaction.atomic
5454
def process_inferences(inferences: List[Inference], advisory: Advisory, improver_name: str):
5555
"""
56+
Return number of inferences processed.
5657
An atomic transaction that updates both the Advisory (e.g. date_improved)
5758
and processes the given inferences to create or update corresponding
5859
database fields.
@@ -61,10 +62,11 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
6162
erroneous. Also, the atomic transaction for every advisory and its
6263
inferences makes sure that date_improved of advisory is consistent.
6364
"""
65+
inferences_processed_count = 0
6466

6567
if not inferences:
66-
logger.warn(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
67-
return
68+
logger.warning(f"Nothing to improve. Source: {improver_name} Advisory id: {advisory.id}")
69+
return inferences_processed_count
6870

6971
logger.info(f"Improving advisory id: {advisory.id}")
7072

@@ -76,7 +78,7 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
7678
)
7779

7880
if not vulnerability:
79-
logger.warn(f"Unable to get vulnerability for inference: {inference!r}")
81+
logger.warning(f"Unable to get vulnerability for inference: {inference!r}")
8082
continue
8183

8284
for ref in inference.references:
@@ -139,8 +141,12 @@ def process_inferences(inferences: List[Inference], advisory: Advisory, improver
139141
cwe_obj, created = Weakness.objects.get_or_create(cwe_id=cwe_id)
140142
cwe_obj.vulnerabilities.add(vulnerability)
141143
cwe_obj.save()
144+
145+
inferences_processed_count += 1
146+
142147
advisory.date_improved = datetime.now(timezone.utc)
143148
advisory.save()
149+
return inferences_processed_count
144150

145151

146152
def create_valid_vulnerability_reference(url, reference_id=None):
@@ -164,7 +170,7 @@ def create_valid_vulnerability_reference(url, reference_id=None):
164170
return reference
165171

166172

167-
def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summary):
173+
def get_or_create_vulnerability_and_aliases(alias_names, vulnerability_id=None, summary=None):
168174
"""
169175
Get or create vulnerabilitiy and aliases such that all existing and new
170176
aliases point to the same vulnerability
@@ -184,7 +190,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
184190
# TODO: It is possible that all those vulnerabilities are actually
185191
# the same at data level, figure out a way to merge them
186192
if len(existing_vulns) > 1:
187-
logger.warn(
193+
logger.warning(
188194
f"Given aliases {alias_names} already exist and do not point "
189195
f"to a single vulnerability. Cannot improve. Skipped."
190196
)
@@ -197,7 +203,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
197203
and vulnerability_id
198204
and existing_alias_vuln.vulnerability_id != vulnerability_id
199205
):
200-
logger.warn(
206+
logger.warning(
201207
f"Given aliases {alias_names!r} already exist and point to existing"
202208
f"vulnerability {existing_alias_vuln}. Unable to create Vulnerability "
203209
f"with vulnerability_id {vulnerability_id}. Skipped"
@@ -210,7 +216,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
210216
try:
211217
vulnerability = Vulnerability.objects.get(vulnerability_id=vulnerability_id)
212218
except Vulnerability.DoesNotExist:
213-
logger.warn(
219+
logger.warning(
214220
f"Given vulnerability_id: {vulnerability_id} does not exist in the database"
215221
)
216222
return
@@ -219,7 +225,7 @@ def get_or_create_vulnerability_and_aliases(vulnerability_id, alias_names, summa
219225
vulnerability.save()
220226

221227
if summary and summary != vulnerability.summary:
222-
logger.warn(
228+
logger.warning(
223229
f"Inconsistent summary for {vulnerability!r}. "
224230
f"Existing: {vulnerability.summary}, provided: {summary}"
225231
)

vulnerabilities/tests/test_improve_runner.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,27 @@
77
# See https://aboutcode.org for more information about nexB OSS projects.
88
#
99

10+
from collections import Counter
11+
1012
import pytest
13+
from django.utils import timezone
14+
from packageurl import PackageURL
15+
from pytest_django.asserts import assertQuerysetEqual
1116

17+
from vulnerabilities.importer import Reference
1218
from vulnerabilities.improve_runner import create_valid_vulnerability_reference
19+
from vulnerabilities.improve_runner import get_or_create_vulnerability_and_aliases
20+
from vulnerabilities.improve_runner import process_inferences
21+
from vulnerabilities.improver import Improver
22+
from vulnerabilities.improver import Inference
23+
from vulnerabilities.models import Advisory
24+
from vulnerabilities.models import Alias
25+
from vulnerabilities.models import Package
26+
from vulnerabilities.models import PackageRelatedVulnerability
27+
from vulnerabilities.models import Vulnerability
28+
from vulnerabilities.models import VulnerabilityReference
29+
from vulnerabilities.models import VulnerabilityRelatedReference
30+
from vulnerabilities.models import VulnerabilitySeverity
1331

1432

1533
@pytest.mark.django_db
@@ -37,3 +55,152 @@ def test_create_valid_vulnerability_reference_accepts_long_references():
3755
url="https://foo.bar",
3856
)
3957
assert result
58+
59+
60+
@pytest.mark.django_db
61+
def test_get_or_create_vulnerability_and_aliases_with_new_vulnerability_and_new_aliases():
62+
alias_names = ["TAYLOR-1337", "SWIFT-1337"]
63+
summary = "Melodious vulnerability"
64+
vulnerability = get_or_create_vulnerability_and_aliases(
65+
alias_names=alias_names, summary=summary
66+
)
67+
assert vulnerability
68+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
69+
assert Counter(alias_names_in_db) == Counter(alias_names)
70+
71+
72+
@pytest.mark.django_db
73+
def test_get_or_create_vulnerability_and_aliases_with_different_vulnerability_and_existing_aliases():
74+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
75+
existing_vulnerability.save()
76+
existing_aliases = []
77+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
78+
for alias in existing_alias_names:
79+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
80+
Alias.objects.bulk_create(existing_aliases)
81+
82+
different_vulnerability = Vulnerability(vulnerability_id="VCID-New")
83+
different_vulnerability.save()
84+
assert not get_or_create_vulnerability_and_aliases(
85+
alias_names=existing_alias_names, vulnerability_id=different_vulnerability.vulnerability_id
86+
)
87+
88+
89+
@pytest.mark.django_db
90+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_new_aliases():
91+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
92+
existing_vulnerability.save()
93+
94+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
95+
vulnerability = get_or_create_vulnerability_and_aliases(
96+
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
97+
)
98+
assert existing_vulnerability == vulnerability
99+
100+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
101+
assert Counter(alias_names_in_db) == Counter(existing_alias_names)
102+
103+
104+
@pytest.mark.django_db
105+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_aliases():
106+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
107+
existing_vulnerability.save()
108+
109+
existing_aliases = []
110+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
111+
for alias in existing_alias_names:
112+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
113+
Alias.objects.bulk_create(existing_aliases)
114+
115+
vulnerability = get_or_create_vulnerability_and_aliases(
116+
vulnerability_id="VCID-Existing", alias_names=existing_alias_names
117+
)
118+
assert existing_vulnerability == vulnerability
119+
120+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
121+
assert Counter(alias_names_in_db) == Counter(existing_alias_names)
122+
123+
124+
@pytest.mark.django_db
125+
def test_get_or_create_vulnerability_and_aliases_with_existing_vulnerability_and_existing_and_new_aliases():
126+
existing_vulnerability = Vulnerability(vulnerability_id="VCID-Existing")
127+
existing_vulnerability.save()
128+
129+
existing_aliases = []
130+
existing_alias_names = ["ALIAS-1", "ALIAS-2"]
131+
for alias in existing_alias_names:
132+
existing_aliases.append(Alias(alias=alias, vulnerability=existing_vulnerability))
133+
Alias.objects.bulk_create(existing_aliases)
134+
135+
new_alias_names = ["ALIAS-3", "ALIAS-4"]
136+
alias_names = existing_alias_names + new_alias_names
137+
vulnerability = get_or_create_vulnerability_and_aliases(
138+
vulnerability_id="VCID-Existing", alias_names=alias_names
139+
)
140+
assert existing_vulnerability == vulnerability
141+
142+
alias_names_in_db = vulnerability.get_aliases.values_list("alias", flat=True)
143+
assert Counter(alias_names_in_db) == Counter(alias_names)
144+
145+
146+
DUMMY_ADVISORY = Advisory(summary="dummy", created_by="tests", date_collected=timezone.now())
147+
148+
149+
@pytest.mark.django_db
150+
def test_process_inferences_with_no_inference():
151+
assert not process_inferences(
152+
inferences=[], advisory=DUMMY_ADVISORY, improver_name="test_improver"
153+
)
154+
155+
156+
@pytest.mark.django_db
157+
def test_process_inferences_with_unknown_but_specified_vulnerability():
158+
inference = Inference(vulnerability_id="VCID-Does-Not-Exist-In-DB", aliases=["MATRIX-Neo"])
159+
assert not process_inferences(
160+
inferences=[inference], advisory=DUMMY_ADVISORY, improver_name="test_improver"
161+
)
162+
163+
164+
INFERENCES = [
165+
Inference(
166+
aliases=["CVE-1", "CVE-2"],
167+
summary="One upon a time, in a package far far away",
168+
affected_purls=[
169+
PackageURL(type="character", namespace="star-wars", name="anakin", version="1")
170+
],
171+
fixed_purl=PackageURL(
172+
type="character", namespace="star-wars", name="darth-vader", version="1"
173+
),
174+
references=[Reference(reference_id="imperial-vessel-1", url="https://m47r1x.github.io")],
175+
)
176+
]
177+
178+
179+
def get_objects_in_all_tables_used_by_process_inferences():
180+
return {
181+
"vulnerabilities": list(Vulnerability.objects.all()),
182+
"aliases": list(Alias.objects.all()),
183+
"references": list(VulnerabilityReference.objects.all()),
184+
"advisories": list(Advisory.objects.all()),
185+
"packages": list(Package.objects.all()),
186+
"references": list(VulnerabilityReference.objects.all()),
187+
"severity": list(VulnerabilitySeverity.objects.all()),
188+
}
189+
190+
191+
@pytest.mark.django_db
192+
def test_process_inferences_idempotency():
193+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
194+
all_objects = get_objects_in_all_tables_used_by_process_inferences()
195+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
196+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver")
197+
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()
198+
199+
200+
@pytest.mark.django_db
201+
def test_process_inference_idempotency_with_different_improver_names():
202+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_one")
203+
all_objects = get_objects_in_all_tables_used_by_process_inferences()
204+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_two")
205+
process_inferences(INFERENCES, DUMMY_ADVISORY, improver_name="test_improver_three")
206+
assert all_objects == get_objects_in_all_tables_used_by_process_inferences()

vulnerabilities/tests/test_improver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def test_inference_to_dict_method_with_vulnerability_id():
3131
"affected_purls": [],
3232
"fixed_purl": None,
3333
"references": [],
34+
"weaknesses": [],
3435
}
3536
assert expected == inference.to_dict()
3637

@@ -46,6 +47,7 @@ def test_inference_to_dict_method_with_purls():
4647
"affected_purls": [purl.to_dict()],
4748
"fixed_purl": purl.to_dict(),
4849
"references": [],
50+
"weaknesses": [],
4951
}
5052
assert expected == inference.to_dict()
5153

0 commit comments

Comments
 (0)