1212import unittest
1313from datetime import datetime
1414from pathlib import Path
15+ from unittest .mock import MagicMock , call , patch
1516
1617from jsonschema import validate
1718from rich .console import Console
@@ -1124,6 +1125,20 @@ class TestOutputEngine(unittest.TestCase):
11241125 ]
11251126
11261127 def setUp (self ) -> None :
1128+ self .all_product_data = [
1129+ ProductInfo (
1130+ product = "product1" ,
1131+ version = "1.0" ,
1132+ vendor = "VendorA" ,
1133+ location = "/usr/local/bin/product" ,
1134+ ),
1135+ ProductInfo (
1136+ product = "product2" ,
1137+ version = "2.0" ,
1138+ vendor = "unknown" ,
1139+ location = "/usr/local/bin/product" ,
1140+ ),
1141+ ]
11271142 self .output_engine = OutputEngine (
11281143 all_cve_data = self .MOCK_OUTPUT ,
11291144 scanned_dir = "" ,
@@ -1134,6 +1149,68 @@ def setUp(self) -> None:
11341149 )
11351150 self .mock_file = tempfile .NamedTemporaryFile ("w+" , encoding = "utf-8" )
11361151
1152+ def test_generate_sbom (self ):
1153+ with patch (
1154+ "cve_bin_tool.output_engine.SBOMPackage"
1155+ ) as mock_sbom_package , patch ("cve_bin_tool.output_engine.SBOMRelationship" ):
1156+ mock_package_instance = MagicMock ()
1157+ mock_sbom_package .return_value = mock_package_instance
1158+
1159+ self .output_engine .generate_sbom (
1160+ all_product_data = self .all_product_data ,
1161+ filename = "test.sbom" ,
1162+ sbom_type = "spdx" ,
1163+ sbom_format = "tag" ,
1164+ sbom_root = "CVE-SCAN" ,
1165+ )
1166+
1167+ # Assertions
1168+ mock_package_instance .set_name .assert_any_call ("CVEBINTOOL-CVE-SCAN" )
1169+
1170+ # Check if set_name is called for each product
1171+ expected_calls = [
1172+ call (product .product ) for product in self .all_product_data
1173+ ]
1174+ mock_package_instance .set_name .assert_has_calls (
1175+ expected_calls , any_order = True
1176+ )
1177+
1178+ # Check if set_version is called for each product
1179+ expected_calls = [
1180+ call (product .version ) for product in self .all_product_data
1181+ ]
1182+ mock_package_instance .set_version .assert_has_calls (
1183+ expected_calls , any_order = True
1184+ )
1185+
1186+ # Check if set_supplier is called for VendorA
1187+ mock_package_instance .set_supplier .assert_any_call (
1188+ "Organization" , "VendorA"
1189+ )
1190+
1191+ for call_args in mock_package_instance .set_supplier .call_args_list :
1192+ args , _ = call_args
1193+ self .assertNotEqual (args , ("Organization" , "unknown" ))
1194+
1195+ # Check if set_licensedeclared and set_licenseconcluded are called for each product
1196+ expected_calls = [call ("NOASSERTION" )] * len (self .all_product_data )
1197+ mock_package_instance .set_licensedeclared .assert_has_calls (
1198+ expected_calls , any_order = True
1199+ )
1200+ mock_package_instance .set_licenseconcluded .assert_has_calls (
1201+ expected_calls , any_order = True
1202+ )
1203+
1204+ # Ensure packages are added to sbom_packages correctly
1205+ expected_packages = {
1206+ mock_package_instance .get_package .return_value ,
1207+ mock_package_instance .get_package .return_value ,
1208+ }
1209+ actual_packages = [
1210+ package for package in self .output_engine .sbom_packages .values ()
1211+ ]
1212+ self .assertEqual (actual_packages , list (expected_packages ))
1213+
11371214 def tearDown (self ) -> None :
11381215 self .mock_file .close ()
11391216
0 commit comments