diff --git a/internal/resolution/datasource/maven_test.go b/internal/resolution/datasource/maven_test.go index 2bc82efeeb..5c665dfbf7 100644 --- a/internal/resolution/datasource/maven_test.go +++ b/internal/resolution/datasource/maven_test.go @@ -2,62 +2,28 @@ package datasource import ( "context" - "io" - "log" - "net/http" - "net/http/httptest" "reflect" - "strings" - "sync" "testing" "deps.dev/util/maven" + "github.com/google/osv-scanner/internal/testutility" ) -type fakeMavenRegistry struct { - mu sync.Mutex - repository map[string]string // path -> response -} - -func (f *fakeMavenRegistry) setResponse(path, response string) { - f.mu.Lock() - defer f.mu.Unlock() - if f.repository == nil { - f.repository = make(map[string]string) - } - f.repository[path] = response -} - -func (f *fakeMavenRegistry) ServeHTTP(w http.ResponseWriter, r *http.Request) { - f.mu.Lock() - resp, ok := f.repository[strings.TrimPrefix(r.URL.Path, "/")] - f.mu.Unlock() - if !ok { - w.WriteHeader(http.StatusNotFound) - resp = "not found" - } - if _, err := io.WriteString(w, resp); err != nil { - log.Fatalf("WriteString: %v", err) - } -} - func TestGetProject(t *testing.T) { t.Parallel() - fakeMaven := &fakeMavenRegistry{} - srv := httptest.NewServer(fakeMaven) - defer srv.Close() + srv := testutility.NewMockHTTPServer(t) client := &MavenRegistryAPIClient{ Registry: srv.URL, } - fakeMaven.setResponse("org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", ` + srv.SetResponse(t, "org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", []byte(` org.example x.y.z 1.0.0 - `) + `)) got, err := client.GetProject(context.Background(), "org.example", "x.y.z", "1.0.0") if err != nil { t.Fatalf("failed to get Maven project %s:%s verion %s: %v", "org.example", "x.y.z", "1.0.0", err) diff --git a/internal/resolution/datasource/npm_registry_test.go b/internal/resolution/datasource/npm_registry_test.go index b991937768..81baa7c880 100644 --- a/internal/resolution/datasource/npm_registry_test.go +++ b/internal/resolution/datasource/npm_registry_test.go @@ -2,64 +2,15 @@ package datasource_test import ( "context" - "log" - "net/http" - "net/http/httptest" - "os" "path/filepath" "slices" "strings" - "sync" "testing" "github.com/google/osv-scanner/internal/resolution/datasource" "github.com/google/osv-scanner/internal/testutility" ) -type fakeNpmRegistry struct { - mu sync.Mutex - repository map[string][]byte // path -> response - expectAuth string -} - -func (f *fakeNpmRegistry) setResponse(path string, response []byte) { - f.mu.Lock() - defer f.mu.Unlock() - if f.repository == nil { - f.repository = make(map[string][]byte) - } - f.repository[path] = response -} - -func (f *fakeNpmRegistry) setExpectAuth(auth string) { - f.mu.Lock() - defer f.mu.Unlock() - f.expectAuth = auth -} - -func (f *fakeNpmRegistry) ServeHTTP(w http.ResponseWriter, r *http.Request) { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path - } - f.mu.Lock() - wantAuth := f.expectAuth - resp, ok := f.repository[strings.TrimPrefix(path, "/")] - f.mu.Unlock() - - if wantAuth != "" && r.Header.Get("Authorization") != wantAuth { - w.WriteHeader(http.StatusUnauthorized) - resp = []byte("unauthorized") - } else if !ok { - w.WriteHeader(http.StatusNotFound) - resp = []byte("not found") - } - - if _, err := w.Write(resp); err != nil { - log.Fatalf("Write: %v", err) - } -} - func TestNpmRegistryClient(t *testing.T) { t.Parallel() @@ -69,34 +20,14 @@ func TestNpmRegistryClient(t *testing.T) { authToken = "bmljZS10b2tlbgo=" ) - reg1 := &fakeNpmRegistry{} - reg1.setExpectAuth("Basic " + auth) - - b, err := os.ReadFile("./fixtures/npm_registry/fake-package.json") - if err != nil { - t.Fatalf("failed to read fake registry response file: %v", err) - } - reg1.setResponse("fake-package", b) - - b, err = os.ReadFile("./fixtures/npm_registry/fake-package-2.2.2.json") - if err != nil { - t.Fatalf("failed to read fake registry response file: %v", err) - } - reg1.setResponse("fake-package/2.2.2", b) - - reg2 := &fakeNpmRegistry{} - reg2.setExpectAuth("Bearer " + authToken) - - b, err = os.ReadFile("./fixtures/npm_registry/@fake-registry-a.json") - if err != nil { - t.Fatalf("failed to read fake registry response file: %v", err) - } - reg2.setResponse("@fake-registry%2Fa", b) + srv1 := testutility.NewMockHTTPServer(t) + srv1.SetAuthorization(t, "Basic "+auth) + srv1.SetResponseFromFile(t, "/fake-package", "./fixtures/npm_registry/fake-package.json") + srv1.SetResponseFromFile(t, "/fake-package/2.2.2", "./fixtures/npm_registry/fake-package-2.2.2.json") - srv1 := httptest.NewServer(reg1) - defer srv1.Close() - srv2 := httptest.NewServer(reg2) - defer srv2.Close() + srv2 := testutility.NewMockHTTPServer(t) + srv2.SetAuthorization(t, "Bearer "+authToken) + srv2.SetResponseFromFile(t, "/@fake-registry%2Fa", "./fixtures/npm_registry/@fake-registry-a.json") npmrcFile := createTempNpmrc(t, ".npmrc") writeToNpmrc(t, npmrcFile, diff --git a/internal/testutility/mock_http.go b/internal/testutility/mock_http.go new file mode 100644 index 0000000000..4350946ee2 --- /dev/null +++ b/internal/testutility/mock_http.go @@ -0,0 +1,85 @@ +package testutility + +import ( + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" +) + +type MockHTTPServer struct { + *httptest.Server + mu sync.Mutex + response map[string][]byte // path -> response + authorization string // expected Authorization header contents +} + +// NewMockHTTPServer starts and returns a new simple HTTP Server for mocking basic requests. +// The Server will automatically be shut down with Close() in the test Cleanup function. +// +// Use the SetResponse / SetResponseFromFile to set the responses for specific URL paths. +func NewMockHTTPServer(t *testing.T) *MockHTTPServer { + t.Helper() + mock := &MockHTTPServer{response: make(map[string][]byte)} + mock.Server = httptest.NewServer(mock) + t.Cleanup(func() { mock.Server.Close() }) + + return mock +} + +// SetResponse sets the Server's response for the URL path to be response bytes. +func (m *MockHTTPServer) SetResponse(t *testing.T, path string, response []byte) { + t.Helper() + m.mu.Lock() + defer m.mu.Unlock() + path = strings.TrimPrefix(path, "/") + m.response[path] = response +} + +// SetResponseFromFile sets the Server's response for the URL path to be the contents of the file at filename. +func (m *MockHTTPServer) SetResponseFromFile(t *testing.T, path string, filename string) { + t.Helper() + b, err := os.ReadFile(filename) + if err != nil { + t.Fatalf("failed to read response file: %v", err) + } + m.SetResponse(t, path, b) +} + +// SetAuthorization sets the contents of the 'Authorization' header the server expects for all endpoints. +// +// The incoming requests' headers must match the auth string exactly, otherwise the server will response with 401 Unauthorized. +// If authorization is unset or empty, the server will not require authorization. +func (m *MockHTTPServer) SetAuthorization(t *testing.T, auth string) { + t.Helper() + m.mu.Lock() + defer m.mu.Unlock() + m.authorization = auth +} + +// ServeHTTP is the http.Handler for the underlying httptest.Server. +func (m *MockHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.RawPath + if path == "" { + path = r.URL.Path + } + m.mu.Lock() + wantAuth := m.authorization + resp, ok := m.response[strings.TrimPrefix(path, "/")] + m.mu.Unlock() + + if wantAuth != "" && r.Header.Get("Authorization") != wantAuth { + w.WriteHeader(http.StatusUnauthorized) + resp = []byte("unauthorized") + } else if !ok { + w.WriteHeader(http.StatusNotFound) + resp = []byte("not found") + } + + if _, err := w.Write(resp); err != nil { + log.Fatalf("Write: %v", err) + } +}