Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ func loadDatabases(
listPackages bool,
offline bool,
batchSize int,
pkgNames []string,
) (OSVDatabases, bool) {
dbs := make(OSVDatabases, 0, len(dbConfigs))

Expand All @@ -202,7 +203,7 @@ func loadDatabases(
for _, dbConfig := range dbConfigs {
r.PrintTextf(" %s", dbConfig.Name)

db, err := database.Load(dbConfig, offline, batchSize)
db, err := database.Load(dbConfig, offline, batchSize, pkgNames)

if err != nil {
r.PrintDatabaseLoadErr(err)
Expand Down Expand Up @@ -591,12 +592,20 @@ This flag can be passed multiple times to ignore different vulnerabilities`)

files.adjustExtraDatabases(*noConfigDatabases, *useAPI, *useDatabases)

var allPackages []string
for _, p := range files {
for _, pkg := range p.lockf.Packages {
allPackages = append(allPackages, pkg.Name)
}
}

dbs, errored := loadDatabases(
r,
uniqueDBConfigs(files.getConfigs()),
*listPackages,
*offline,
*batchSize,
allPackages,
)

if errored {
Expand Down
6 changes: 3 additions & 3 deletions pkg/database/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ func (dbc Config) Identifier() string {
var ErrUnsupportedDatabaseType = errors.New("unsupported database source type")

// Load initializes a new OSV database based on the given Config
func Load(config Config, offline bool, batchSize int) (DB, error) {
func Load(config Config, offline bool, batchSize int, pkgNames []string) (DB, error) {
switch config.Type {
case "zip":
return NewZippedDB(config, offline)
return NewZippedDB(config, offline, pkgNames)
case "api":
return NewAPIDB(config, offline, batchSize)
case "dir":
return NewDirDB(config, offline)
return NewDirDB(config, offline, pkgNames)
}

return nil, fmt.Errorf("%w %s", ErrUnsupportedDatabaseType, config.Type)
Expand Down
8 changes: 4 additions & 4 deletions pkg/database/dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var ErrDirPathWrongProtocol = errors.New("directory path must start with \"file:

// load walks the filesystem starting with the working directory within the local path,
// loading all OSVs found along the way.
func (db *DirDB) load() error {
func (db *DirDB) load(pkgNames []string) error {
db.vulnerabilities = make(map[string][]OSV)

if !strings.HasPrefix(db.LocalPath, "file:") {
Expand Down Expand Up @@ -78,7 +78,7 @@ func (db *DirDB) load() error {
return nil
}

db.addVulnerability(pa)
db.addVulnerability(pa, pkgNames)

return nil
})
Expand All @@ -94,15 +94,15 @@ func (db *DirDB) load() error {
return nil
}

func NewDirDB(config Config, offline bool) (*DirDB, error) {
func NewDirDB(config Config, offline bool, pkgNames []string) (*DirDB, error) {
db := &DirDB{
name: config.Name,
identifier: config.Identifier(),
LocalPath: config.URL,
WorkingDirectory: config.WorkingDirectory,
Offline: offline,
}
if err := db.load(); err != nil {
if err := db.load(pkgNames); err != nil {
return nil, fmt.Errorf("unable to load OSV database: %w", err)
}

Expand Down
38 changes: 33 additions & 5 deletions pkg/database/dir_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ func TestNewDirDB(t *testing.T) {
osvs := []database.OSV{
withDefaultAffected("OSV-1"),
withDefaultAffected("OSV-2"),
{
ID: "OSV-3",
Affected: []database.Affected{
{Package: database.Package{Ecosystem: "PyPi", Name: "mine2"}, Versions: database.Versions{}},
},
},
{
ID: "GHSA-1234",
Affected: []database.Affected{
Expand All @@ -22,7 +28,7 @@ func TestNewDirDB(t *testing.T) {
},
}

db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false)
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false, nil)

if err != nil {
t.Fatalf("unexpected error \"%v\"", err)
Expand All @@ -34,7 +40,7 @@ func TestNewDirDB(t *testing.T) {
func TestNewDirDB_InvalidURI(t *testing.T) {
t.Parallel()

db, err := database.NewDirDB(database.Config{URL: "file://\\"}, false)
db, err := database.NewDirDB(database.Config{URL: "file://\\"}, false, nil)

if err == nil {
t.Fatalf("NewDirDB() did not return expected error")
Expand All @@ -48,7 +54,7 @@ func TestNewDirDB_InvalidURI(t *testing.T) {
func TestNewDirDB_NotFileProtocol(t *testing.T) {
t.Parallel()

db, err := database.NewDirDB(database.Config{URL: "https://mysite.com/my.zip"}, false)
db, err := database.NewDirDB(database.Config{URL: "https://mysite.com/my.zip"}, false, nil)

if err == nil {
t.Fatalf("NewDirDB() did not return expected error")
Expand All @@ -66,7 +72,7 @@ func TestNewDirDB_NotFileProtocol(t *testing.T) {
func TestNewDirDB_DoesNotExist(t *testing.T) {
t.Parallel()

db, err := database.NewDirDB(database.Config{URL: "file:/testdata/nowhere"}, false)
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/nowhere"}, false, nil)

if err == nil {
t.Fatalf("NewDirDB() did not return expected error")
Expand All @@ -82,11 +88,33 @@ func TestNewDirDB_WorkingDirectory(t *testing.T) {

osvs := []database.OSV{withDefaultAffected("OSV-1")}

db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db", WorkingDirectory: "nested-1"}, false)
db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db", WorkingDirectory: "nested-1"}, false, nil)

if err != nil {
t.Fatalf("unexpected error \"%v\"", err)
}

expectDBToHaveOSVs(t, db, osvs)
}

func TestNewDirDB_WithSpecificPackages(t *testing.T) {
t.Parallel()

db, err := database.NewDirDB(database.Config{URL: "file:/testdata/db"}, false, []string{"mine", "request"})

if err != nil {
t.Fatalf("unexpected error \"%v\"", err)
}

expectDBToHaveOSVs(t, db, []database.OSV{
withDefaultAffected("OSV-1"),
withDefaultAffected("OSV-2"),
{
ID: "GHSA-1234",
Affected: []database.Affected{
{Package: database.Package{Ecosystem: "npm", Name: "request"}},
{Package: database.Package{Ecosystem: "npm", Name: "@cypress/request"}},
},
},
})
}
4 changes: 2 additions & 2 deletions pkg/database/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestLoad(t *testing.T) {
}

for _, typ := range types {
_, err := database.Load(database.Config{Type: typ}, false, 100)
_, err := database.Load(database.Config{Type: typ}, false, 100, nil)

if err == nil {
t.Fatalf("NewDirDB() did not return expected error")
Expand All @@ -28,7 +28,7 @@ func TestLoad(t *testing.T) {
func TestLoad_BadType(t *testing.T) {
t.Parallel()

db, err := database.Load(database.Config{Type: "file"}, false, 100)
db, err := database.Load(database.Config{Type: "file"}, false, 100, nil)

if err == nil {
t.Fatalf("NewDirDB() did not return expected error")
Expand Down
8 changes: 7 additions & 1 deletion pkg/database/mem-check.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ type memDB struct {
VulnerabilitiesCount int
}

func (db *memDB) addVulnerability(osv OSV) {
func (db *memDB) addVulnerability(osv OSV, pkgNames []string) {
db.VulnerabilitiesCount++

// if we have been provided a list of package names, only load advisories
// that might actually affect those packages, rather than all advisories
if len(pkgNames) != 0 && !mightAffectPackages(osv, pkgNames) {
return
}

for _, affected := range osv.Affected {
hash := string(affected.Package.Ecosystem) + "-" + affected.Package.NormalizedName()
vulns := db.vulnerabilities[hash]
Expand Down
12 changes: 12 additions & 0 deletions pkg/database/testdata/db/nested-2/osv-3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"id": "OSV-3",
"affected": [
{
"package": {
"name": "mine2",
"ecosystem": "PyPi"
},
"versions": []
}
]
}
24 changes: 18 additions & 6 deletions pkg/database/zip.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,21 @@ func (db *ZipDB) fetchZip() ([]byte, error) {
return body, nil
}

func mightAffectPackages(v OSV, names []string) bool {
for _, affected := range v.Affected {
for _, name := range names {
if affected.Package.Name == name {
return true
}
}
}

return false
}

// Loads the given zip file into the database as an OSV.
// It is assumed that the file is JSON and in the working directory of the db
func (db *ZipDB) loadZipFile(zipFile *zip.File) {
func (db *ZipDB) loadZipFile(zipFile *zip.File, pkgNames []string) {
file, err := zipFile.Open()
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Could not read %s: %v\n", zipFile.Name, err)
Expand All @@ -152,7 +164,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
return
}

db.addVulnerability(osv)
db.addVulnerability(osv, pkgNames)
}

// load fetches a zip archive of the OSV database and loads known vulnerabilities
Expand All @@ -161,7 +173,7 @@ func (db *ZipDB) loadZipFile(zipFile *zip.File) {
// Internally, the archive is cached along with the date that it was fetched
// so that a new version of the archive is only downloaded if it has been
// modified, per HTTP caching standards.
func (db *ZipDB) load() error {
func (db *ZipDB) load(pkgNames []string) error {
db.vulnerabilities = make(map[string][]OSV)

body, err := db.fetchZip()
Expand All @@ -185,21 +197,21 @@ func (db *ZipDB) load() error {
continue
}

db.loadZipFile(zipFile)
db.loadZipFile(zipFile, pkgNames)
}

return nil
}

func NewZippedDB(config Config, offline bool) (*ZipDB, error) {
func NewZippedDB(config Config, offline bool, pkgNames []string) (*ZipDB, error) {
db := &ZipDB{
name: config.Name,
identifier: config.Identifier(),
ArchiveURL: config.URL,
WorkingDirectory: config.WorkingDirectory,
Offline: offline,
}
if err := db.load(); err != nil {
if err := db.load(pkgNames); err != nil {
return nil, fmt.Errorf("unable to fetch OSV database: %w", err)
}

Expand Down
Loading