Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(detector): standardize db.NewDB to db.CloseDB #1380

Merged
merged 8 commits into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 79 additions & 80 deletions detector/cve_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,98 +22,95 @@ import (
)

type goCveDictClient struct {
cnf config.VulnDictInterface
driver cvedb.DB
driver cvedb.DB
baseURL string
}

func newGoCveDictClient(cnf config.VulnDictInterface, o logging.LogOpts) (*goCveDictClient, error) {
if err := cvelog.SetLogger(o.LogToFile, o.LogDir, o.Debug, o.LogJSON); err != nil {
return nil, err
return nil, xerrors.Errorf("Failed to set go-cve-dictionary logger. err: %w", err)
}

driver, locked, err := newCveDB(cnf)
if locked {
return nil, xerrors.Errorf("SQLite3 is locked: %s", cnf.GetSQLite3Path())
} else if err != nil {
return nil, err
driver, err := newCveDB(cnf)
if err != nil {
return nil, xerrors.Errorf("Failed to newCveDB. err: %w", err)
}
return &goCveDictClient{cnf: cnf, driver: driver}, nil
return &goCveDictClient{driver: driver, baseURL: cnf.GetURL()}, nil
}

func (api goCveDictClient) closeDB() error {
if api.driver == nil {
func (client goCveDictClient) closeDB() error {
if client.driver == nil {
return nil
}
return api.driver.CloseDB()
}

func (api goCveDictClient) fetchCveDetails(cveIDs []string) (cveDetails []cvemodels.CveDetail, err error) {
m, err := api.driver.GetMulti(cveIDs)
if err != nil {
return nil, xerrors.Errorf("Failed to GetMulti. err: %w", err)
}
for _, v := range m {
cveDetails = append(cveDetails, v)
}
return cveDetails, nil
return client.driver.CloseDB()
}

type response struct {
Key string
CveDetail cvemodels.CveDetail
}

func (api goCveDictClient) fetchCveDetailsViaHTTP(cveIDs []string) (cveDetails []cvemodels.CveDetail, err error) {
reqChan := make(chan string, len(cveIDs))
resChan := make(chan response, len(cveIDs))
errChan := make(chan error, len(cveIDs))
defer close(reqChan)
defer close(resChan)
defer close(errChan)

go func() {
for _, cveID := range cveIDs {
reqChan <- cveID
func (client goCveDictClient) fetchCveDetails(cveIDs []string) (cveDetails []cvemodels.CveDetail, err error) {
if client.driver == nil {
reqChan := make(chan string, len(cveIDs))
resChan := make(chan response, len(cveIDs))
errChan := make(chan error, len(cveIDs))
defer close(reqChan)
defer close(resChan)
defer close(errChan)

go func() {
for _, cveID := range cveIDs {
reqChan <- cveID
}
}()

concurrency := 10
tasks := util.GenWorkers(concurrency)
for range cveIDs {
tasks <- func() {
select {
case cveID := <-reqChan:
url, err := util.URLPathJoin(client.baseURL, "cves", cveID)
if err != nil {
errChan <- err
} else {
logging.Log.Debugf("HTTP Request to %s", url)
httpGet(cveID, url, resChan, errChan)
}
}
}
}
}()

concurrency := 10
tasks := util.GenWorkers(concurrency)
for range cveIDs {
tasks <- func() {
timeout := time.After(2 * 60 * time.Second)
var errs []error
for range cveIDs {
select {
case cveID := <-reqChan:
url, err := util.URLPathJoin(api.cnf.GetURL(), "cves", cveID)
if err != nil {
errChan <- err
} else {
logging.Log.Debugf("HTTP Request to %s", url)
api.httpGet(cveID, url, resChan, errChan)
}
case res := <-resChan:
cveDetails = append(cveDetails, res.CveDetail)
case err := <-errChan:
errs = append(errs, err)
case <-timeout:
return nil, xerrors.New("Timeout Fetching CVE")
}
}
}

timeout := time.After(2 * 60 * time.Second)
var errs []error
for range cveIDs {
select {
case res := <-resChan:
cveDetails = append(cveDetails, res.CveDetail)
case err := <-errChan:
errs = append(errs, err)
case <-timeout:
return nil, xerrors.New("Timeout Fetching CVE")
if len(errs) != 0 {
return nil,
xerrors.Errorf("Failed to fetch CVE. err: %w", errs)
}
} else {
m, err := client.driver.GetMulti(cveIDs)
if err != nil {
return nil, xerrors.Errorf("Failed to GetMulti. err: %w", err)
}
for _, v := range m {
cveDetails = append(cveDetails, v)
}
}
if len(errs) != 0 {
return nil,
xerrors.Errorf("Failed to fetch CVE. err: %w", errs)
}
return
return cveDetails, nil
}

func (api goCveDictClient) httpGet(key, url string, resChan chan<- response, errChan chan<- error) {
func httpGet(key, url string, resChan chan<- response, errChan chan<- error) {
var body string
var errs []error
var resp *http.Response
Expand Down Expand Up @@ -144,21 +141,21 @@ func (api goCveDictClient) httpGet(key, url string, resChan chan<- response, err
}
}

func (api goCveDictClient) detectCveByCpeURI(cpeURI string, useJVN bool) (cves []cvemodels.CveDetail, err error) {
if api.cnf.IsFetchViaHTTP() {
url, err := util.URLPathJoin(api.cnf.GetURL(), "cpes")
func (client goCveDictClient) detectCveByCpeURI(cpeURI string, useJVN bool) (cves []cvemodels.CveDetail, err error) {
if client.driver == nil {
url, err := util.URLPathJoin(client.baseURL, "cpes")
if err != nil {
return nil, err
return nil, xerrors.Errorf("Failed to join URLPath. err: %w", err)
}

query := map[string]string{"name": cpeURI}
logging.Log.Debugf("HTTP Request to %s, query: %#v", url, query)
if cves, err = api.httpPost(url, query); err != nil {
return nil, err
if cves, err = httpPost(url, query); err != nil {
return nil, xerrors.Errorf("Failed to post HTTP Request. err: %w", err)
}
} else {
if cves, err = api.driver.GetByCpeURI(cpeURI); err != nil {
return nil, err
if cves, err = client.driver.GetByCpeURI(cpeURI); err != nil {
return nil, xerrors.Errorf("Failed to get CVEs by CPEURI. err: %w", err)
}
}

Expand All @@ -177,7 +174,7 @@ func (api goCveDictClient) detectCveByCpeURI(cpeURI string, useJVN bool) (cves [
return nvdCves, nil
}

func (api goCveDictClient) httpPost(url string, query map[string]string) ([]cvemodels.CveDetail, error) {
func httpPost(url string, query map[string]string) ([]cvemodels.CveDetail, error) {
var body string
var errs []error
var resp *http.Response
Expand Down Expand Up @@ -208,18 +205,20 @@ func (api goCveDictClient) httpPost(url string, query map[string]string) ([]cvem
return cveDetails, nil
}

func newCveDB(cnf config.VulnDictInterface) (driver cvedb.DB, locked bool, err error) {
func newCveDB(cnf config.VulnDictInterface) (cvedb.DB, error) {
if cnf.IsFetchViaHTTP() {
return nil, false, nil
return nil, nil
}
path := cnf.GetURL()
if cnf.GetType() == "sqlite3" {
path = cnf.GetSQLite3Path()
}
driver, locked, err = cvedb.NewDB(cnf.GetType(), path, cnf.GetDebugSQL(), cvedb.Option{})
driver, locked, err := cvedb.NewDB(cnf.GetType(), path, cnf.GetDebugSQL(), cvedb.Option{})
if err != nil {
err = xerrors.Errorf("Failed to init CVE DB. err: %w, path: %s", err, path)
return nil, locked, err
if locked {
return nil, xerrors.Errorf("Failed to init CVE DB. SQLite3: %s is locked. err: %w", cnf.GetSQLite3Path(), err)
}
return nil, xerrors.Errorf("Failed to init CVE DB. DB Path: %s, err: %w", path, err)
}
return driver, false, nil
return driver, nil
}
Loading