Skip to content

Commit c30c7b7

Browse files
authored
test: added test for generate_sbom function (#4060)
* test: added test for generate_sbom function Signed-off-by: Meet Soni <meetsoni3017@gmail.com>
1 parent f4c7e91 commit c30c7b7

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

cve_bin_tool/output_engine/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ def __init__(
713713
self.sbom_format = sbom_format
714714
self.sbom_root = sbom_root
715715
self.offline = offline
716+
self.sbom_packages = {}
716717

717718
def output_cves(self, outfile, output_type="console"):
718719
"""Output a list of CVEs
@@ -914,7 +915,6 @@ def generate_sbom(
914915
):
915916
"""Create SBOM package and generate SBOM file."""
916917
# Create SBOM
917-
sbom_packages = {}
918918
sbom_relationships = []
919919
my_package = SBOMPackage()
920920
sbom_relationship = SBOMRelationship()
@@ -933,7 +933,7 @@ def generate_sbom(
933933
my_package.set_supplier("UNKNOWN", "NOASSERTION")
934934

935935
# Store package data
936-
sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
936+
self.sbom_packages[(my_package.get_name(), my_package.get_value("version"))] = (
937937
my_package.get_package()
938938
)
939939
sbom_relationship.initialise()
@@ -945,18 +945,18 @@ def generate_sbom(
945945
my_package.initialise()
946946
my_package.set_name(product_data.product)
947947
my_package.set_version(product_data.version)
948-
if product_data.vendor != "UNKNOWN":
948+
if product_data.vendor.casefold() != "UNKNOWN".casefold():
949949
my_package.set_supplier("Organization", product_data.vendor)
950950
my_package.set_licensedeclared(license)
951951
my_package.set_licenseconcluded(license)
952952
if not (
953953
(my_package.get_name(), my_package.get_value("version"))
954-
in sbom_packages
954+
in self.sbom_packages
955955
and product_data.vendor == "unknown"
956956
):
957957
location = product_data.location
958958
my_package.set_evidence(location) # Set location directly
959-
sbom_packages[
959+
self.sbom_packages[
960960
(my_package.get_name(), my_package.get_value("version"))
961961
] = my_package.get_package()
962962
sbom_relationship.initialise()
@@ -967,7 +967,7 @@ def generate_sbom(
967967

968968
# Generate SBOM
969969
my_sbom = SBOM()
970-
my_sbom.add_packages(sbom_packages)
970+
my_sbom.add_packages(self.sbom_packages)
971971
my_sbom.add_relationships(sbom_relationships)
972972
my_generator = SBOMGenerator(
973973
sbom_type=sbom_type,

test/test_output_engine.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import unittest
1313
from datetime import datetime
1414
from pathlib import Path
15+
from unittest.mock import MagicMock, call, patch
1516

1617
from jsonschema import validate
1718
from rich.console import Console
@@ -1124,6 +1125,20 @@ class TestOutputEngine(unittest.TestCase):
11241125
]
11251126

11261127
def setUp(self) -> None:
1128+
self.all_product_data = [
1129+
ProductInfo(
1130+
product="product1",
1131+
version="1.0",
1132+
vendor="VendorA",
1133+
location="/usr/local/bin/product",
1134+
),
1135+
ProductInfo(
1136+
product="product2",
1137+
version="2.0",
1138+
vendor="unknown",
1139+
location="/usr/local/bin/product",
1140+
),
1141+
]
11271142
self.output_engine = OutputEngine(
11281143
all_cve_data=self.MOCK_OUTPUT,
11291144
scanned_dir="",
@@ -1134,6 +1149,68 @@ def setUp(self) -> None:
11341149
)
11351150
self.mock_file = tempfile.NamedTemporaryFile("w+", encoding="utf-8")
11361151

1152+
def test_generate_sbom(self):
1153+
with patch(
1154+
"cve_bin_tool.output_engine.SBOMPackage"
1155+
) as mock_sbom_package, patch("cve_bin_tool.output_engine.SBOMRelationship"):
1156+
mock_package_instance = MagicMock()
1157+
mock_sbom_package.return_value = mock_package_instance
1158+
1159+
self.output_engine.generate_sbom(
1160+
all_product_data=self.all_product_data,
1161+
filename="test.sbom",
1162+
sbom_type="spdx",
1163+
sbom_format="tag",
1164+
sbom_root="CVE-SCAN",
1165+
)
1166+
1167+
# Assertions
1168+
mock_package_instance.set_name.assert_any_call("CVEBINTOOL-CVE-SCAN")
1169+
1170+
# Check if set_name is called for each product
1171+
expected_calls = [
1172+
call(product.product) for product in self.all_product_data
1173+
]
1174+
mock_package_instance.set_name.assert_has_calls(
1175+
expected_calls, any_order=True
1176+
)
1177+
1178+
# Check if set_version is called for each product
1179+
expected_calls = [
1180+
call(product.version) for product in self.all_product_data
1181+
]
1182+
mock_package_instance.set_version.assert_has_calls(
1183+
expected_calls, any_order=True
1184+
)
1185+
1186+
# Check if set_supplier is called for VendorA
1187+
mock_package_instance.set_supplier.assert_any_call(
1188+
"Organization", "VendorA"
1189+
)
1190+
1191+
for call_args in mock_package_instance.set_supplier.call_args_list:
1192+
args, _ = call_args
1193+
self.assertNotEqual(args, ("Organization", "unknown"))
1194+
1195+
# Check if set_licensedeclared and set_licenseconcluded are called for each product
1196+
expected_calls = [call("NOASSERTION")] * len(self.all_product_data)
1197+
mock_package_instance.set_licensedeclared.assert_has_calls(
1198+
expected_calls, any_order=True
1199+
)
1200+
mock_package_instance.set_licenseconcluded.assert_has_calls(
1201+
expected_calls, any_order=True
1202+
)
1203+
1204+
# Ensure packages are added to sbom_packages correctly
1205+
expected_packages = {
1206+
mock_package_instance.get_package.return_value,
1207+
mock_package_instance.get_package.return_value,
1208+
}
1209+
actual_packages = [
1210+
package for package in self.output_engine.sbom_packages.values()
1211+
]
1212+
self.assertEqual(actual_packages, list(expected_packages))
1213+
11371214
def tearDown(self) -> None:
11381215
self.mock_file.close()
11391216

0 commit comments

Comments
 (0)