Skip to content

Commit

Permalink
feat(misconf): Add support for using spec from on-disk bundle (#7179)
Browse files Browse the repository at this point in the history
  • Loading branch information
simar7 authored Aug 27, 2024
1 parent 45a9627 commit be86126
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 5 deletions.
41 changes: 37 additions & 4 deletions pkg/compliance/spec/compliance.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package spec
import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/samber/lo"
Expand All @@ -11,6 +12,7 @@ import (

sp "github.com/aquasecurity/trivy-checks/pkg/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/types"
)

Expand Down Expand Up @@ -70,18 +72,41 @@ func scannerByCheckID(checkID string) types.Scanner {
}
}

func checksDir(cacheDir string) string {
return filepath.Join(cacheDir, "policy")
}

func complianceSpecDir(cacheDir string) string {
return filepath.Join(checksDir(cacheDir), "content", "specs", "compliance")
}

// GetComplianceSpec accepct compliance flag name/path and return builtin or file system loaded spec
func GetComplianceSpec(specNameOrPath string) (ComplianceSpec, error) {
func GetComplianceSpec(specNameOrPath, cacheDir string) (ComplianceSpec, error) {
if specNameOrPath == "" {
return ComplianceSpec{}, nil
}

var b []byte
var err error
if strings.HasPrefix(specNameOrPath, "@") {
if strings.HasPrefix(specNameOrPath, "@") { // load user specified spec from disk
b, err = os.ReadFile(strings.TrimPrefix(specNameOrPath, "@"))
if err != nil {
return ComplianceSpec{}, fmt.Errorf("error retrieving compliance spec from path: %w", err)
}
log.Debug("Compliance spec loaded from specified path", log.String("path", specNameOrPath))
} else {
// TODO: GetSpecByName() should return []byte
b = []byte(sp.NewSpecLoader().GetSpecByName(specNameOrPath))
_, err := os.Stat(filepath.Join(checksDir(cacheDir), "metadata.json"))
if err != nil { // cache corrupt or bundle does not exist, load embedded version
b = []byte(sp.NewSpecLoader().GetSpecByName(specNameOrPath))
log.Debug("Compliance spec loaded from embedded library", log.String("spec", specNameOrPath))
} else {
// load from bundle on disk
b, err = LoadFromBundle(cacheDir, specNameOrPath)
if err != nil {
return ComplianceSpec{}, err
}
log.Debug("Compliance spec loaded from disk bundle", log.String("spec", specNameOrPath))
}
}

var complianceSpec ComplianceSpec
Expand All @@ -91,3 +116,11 @@ func GetComplianceSpec(specNameOrPath string) (ComplianceSpec, error) {
return complianceSpec, nil

}

func LoadFromBundle(cacheDir, specNameOrPath string) ([]byte, error) {
b, err := os.ReadFile(filepath.Join(complianceSpecDir(cacheDir), specNameOrPath+".yaml"))
if err != nil {
return nil, fmt.Errorf("error retrieving compliance spec from bundle %s: %w", specNameOrPath, err)
}
return b, nil
}
52 changes: 52 additions & 0 deletions pkg/compliance/spec/compliance_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package spec_test

import (
"path/filepath"
"sort"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/compliance/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
Expand Down Expand Up @@ -239,3 +241,53 @@ func TestComplianceSpec_CheckIDs(t *testing.T) {
})
}
}

func TestComplianceSpec_LoadFromDiskBundle(t *testing.T) {

t.Run("load user specified spec from disk", func(t *testing.T) {
cs, err := spec.GetComplianceSpec(filepath.Join("@testdata", "testcache", "policy", "content", "specs", "compliance", "testspec.yaml"), filepath.Join("testdata", "testcache"))
require.NoError(t, err)
assert.Equal(t, spec.ComplianceSpec{Spec: iacTypes.Spec{
ID: "test-spec-1.2",
Title: "Test Spec",
Description: "This is a test spec",
RelatedResources: []string{
"https://www.google.ca",
},
Version: "1.2",
Controls: []iacTypes.Control{
{
Name: "moar-testing",
Description: "Test needs foo bar baz",
ID: "1.1",
Checks: []iacTypes.SpecCheck{
{ID: "AVD-TEST-1234"},
},
Severity: "LOW",
},
},
}}, cs)
})

t.Run("load user specified spec from disk fails", func(t *testing.T) {
_, err := spec.GetComplianceSpec("@doesnotexist", "does-not-matter")
assert.Contains(t, err.Error(), "error retrieving compliance spec from path")
})

t.Run("bundle does not exist", func(t *testing.T) {
cs, err := spec.GetComplianceSpec("aws-cis-1.2", "does-not-matter")
require.NoError(t, err)
assert.Equal(t, "aws-cis-1.2", cs.Spec.ID)
})

t.Run("load spec from disk", func(t *testing.T) {
cs, err := spec.GetComplianceSpec("testspec", filepath.Join("testdata", "testcache"))
require.NoError(t, err)
assert.Equal(t, "test-spec-1.2", cs.Spec.ID)
})

t.Run("load spec yaml unmarshal failure", func(t *testing.T) {
_, err := spec.GetComplianceSpec("invalid", filepath.Join("testdata", "testcache"))
assert.Contains(t, err.Error(), "spec yaml decode error")
})
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
this is not yaml but easier to read
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
spec:
id: test-spec-1.2
title: Test Spec
description: This is a test spec
version: "1.2"
relatedResources:
- https://www.google.ca
controls:
- id: "1.1"
name: moar-testing
description: |-
Test needs foo bar baz
checks:
- id: AVD-TEST-1234
severity: LOW
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"Digest":"sha256:ef2d9ad4fce0f933b20a662004d7e55bf200987c180e7f2cd531af631f408bb3","DownloadedAt":"2024-08-07T20:07:48.917915-06:00"}
3 changes: 2 additions & 1 deletion pkg/flag/report_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"golang.org/x/xerrors"

dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
"github.com/aquasecurity/trivy/pkg/cache"
"github.com/aquasecurity/trivy/pkg/compliance/spec"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/result"
Expand Down Expand Up @@ -260,7 +261,7 @@ func loadComplianceTypes(compliance string) (spec.ComplianceSpec, error) {
return spec.ComplianceSpec{}, xerrors.Errorf("unknown compliance : %v", compliance)
}

cs, err := spec.GetComplianceSpec(compliance)
cs, err := spec.GetComplianceSpec(compliance, cache.DefaultDir())
if err != nil {
return spec.ComplianceSpec{}, xerrors.Errorf("spec loading from file system error: %w", err)
}
Expand Down

0 comments on commit be86126

Please sign in to comment.