Skip to content

Commit

Permalink
Make MockHTTPServer for tests (google#888)
Browse files Browse the repository at this point in the history
Factored out the httptest.Server / handler into one struct for easier
reuse in tests.
  • Loading branch information
michaelkedar authored Mar 26, 2024
1 parent f97e5d1 commit 1c01916
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 114 deletions.
42 changes: 4 additions & 38 deletions internal/resolution/datasource/maven_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,62 +2,28 @@ package datasource

import (
"context"
"io"
"log"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"sync"
"testing"

"deps.dev/util/maven"
"github.com/google/osv-scanner/internal/testutility"
)

type fakeMavenRegistry struct {
mu sync.Mutex
repository map[string]string // path -> response
}

func (f *fakeMavenRegistry) setResponse(path, response string) {
f.mu.Lock()
defer f.mu.Unlock()
if f.repository == nil {
f.repository = make(map[string]string)
}
f.repository[path] = response
}

func (f *fakeMavenRegistry) ServeHTTP(w http.ResponseWriter, r *http.Request) {
f.mu.Lock()
resp, ok := f.repository[strings.TrimPrefix(r.URL.Path, "/")]
f.mu.Unlock()
if !ok {
w.WriteHeader(http.StatusNotFound)
resp = "not found"
}
if _, err := io.WriteString(w, resp); err != nil {
log.Fatalf("WriteString: %v", err)
}
}

func TestGetProject(t *testing.T) {
t.Parallel()

fakeMaven := &fakeMavenRegistry{}
srv := httptest.NewServer(fakeMaven)
defer srv.Close()
srv := testutility.NewMockHTTPServer(t)
client := &MavenRegistryAPIClient{
Registry: srv.URL,
}

fakeMaven.setResponse("org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", `
srv.SetResponse(t, "org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", []byte(`
<project>
<groupId>org.example</groupId>
<artifactId>x.y.z</artifactId>
<version>1.0.0</version>
</project>
`)
`))
got, err := client.GetProject(context.Background(), "org.example", "x.y.z", "1.0.0")
if err != nil {
t.Fatalf("failed to get Maven project %s:%s verion %s: %v", "org.example", "x.y.z", "1.0.0", err)
Expand Down
83 changes: 7 additions & 76 deletions internal/resolution/datasource/npm_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,15 @@ package datasource_test

import (
"context"
"log"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"testing"

"github.com/google/osv-scanner/internal/resolution/datasource"
"github.com/google/osv-scanner/internal/testutility"
)

type fakeNpmRegistry struct {
mu sync.Mutex
repository map[string][]byte // path -> response
expectAuth string
}

func (f *fakeNpmRegistry) setResponse(path string, response []byte) {
f.mu.Lock()
defer f.mu.Unlock()
if f.repository == nil {
f.repository = make(map[string][]byte)
}
f.repository[path] = response
}

func (f *fakeNpmRegistry) setExpectAuth(auth string) {
f.mu.Lock()
defer f.mu.Unlock()
f.expectAuth = auth
}

func (f *fakeNpmRegistry) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.RawPath
if path == "" {
path = r.URL.Path
}
f.mu.Lock()
wantAuth := f.expectAuth
resp, ok := f.repository[strings.TrimPrefix(path, "/")]
f.mu.Unlock()

if wantAuth != "" && r.Header.Get("Authorization") != wantAuth {
w.WriteHeader(http.StatusUnauthorized)
resp = []byte("unauthorized")
} else if !ok {
w.WriteHeader(http.StatusNotFound)
resp = []byte("not found")
}

if _, err := w.Write(resp); err != nil {
log.Fatalf("Write: %v", err)
}
}

func TestNpmRegistryClient(t *testing.T) {
t.Parallel()

Expand All @@ -69,34 +20,14 @@ func TestNpmRegistryClient(t *testing.T) {
authToken = "bmljZS10b2tlbgo="
)

reg1 := &fakeNpmRegistry{}
reg1.setExpectAuth("Basic " + auth)

b, err := os.ReadFile("./fixtures/npm_registry/fake-package.json")
if err != nil {
t.Fatalf("failed to read fake registry response file: %v", err)
}
reg1.setResponse("fake-package", b)

b, err = os.ReadFile("./fixtures/npm_registry/fake-package-2.2.2.json")
if err != nil {
t.Fatalf("failed to read fake registry response file: %v", err)
}
reg1.setResponse("fake-package/2.2.2", b)

reg2 := &fakeNpmRegistry{}
reg2.setExpectAuth("Bearer " + authToken)

b, err = os.ReadFile("./fixtures/npm_registry/@fake-registry-a.json")
if err != nil {
t.Fatalf("failed to read fake registry response file: %v", err)
}
reg2.setResponse("@fake-registry%2Fa", b)
srv1 := testutility.NewMockHTTPServer(t)
srv1.SetAuthorization(t, "Basic "+auth)
srv1.SetResponseFromFile(t, "/fake-package", "./fixtures/npm_registry/fake-package.json")
srv1.SetResponseFromFile(t, "/fake-package/2.2.2", "./fixtures/npm_registry/fake-package-2.2.2.json")

srv1 := httptest.NewServer(reg1)
defer srv1.Close()
srv2 := httptest.NewServer(reg2)
defer srv2.Close()
srv2 := testutility.NewMockHTTPServer(t)
srv2.SetAuthorization(t, "Bearer "+authToken)
srv2.SetResponseFromFile(t, "/@fake-registry%2Fa", "./fixtures/npm_registry/@fake-registry-a.json")

npmrcFile := createTempNpmrc(t, ".npmrc")
writeToNpmrc(t, npmrcFile,
Expand Down
85 changes: 85 additions & 0 deletions internal/testutility/mock_http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package testutility

import (
"log"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
)

type MockHTTPServer struct {
*httptest.Server
mu sync.Mutex
response map[string][]byte // path -> response
authorization string // expected Authorization header contents
}

// NewMockHTTPServer starts and returns a new simple HTTP Server for mocking basic requests.
// The Server will automatically be shut down with Close() in the test Cleanup function.
//
// Use the SetResponse / SetResponseFromFile to set the responses for specific URL paths.
func NewMockHTTPServer(t *testing.T) *MockHTTPServer {
t.Helper()
mock := &MockHTTPServer{response: make(map[string][]byte)}
mock.Server = httptest.NewServer(mock)
t.Cleanup(func() { mock.Server.Close() })

return mock
}

// SetResponse sets the Server's response for the URL path to be response bytes.
func (m *MockHTTPServer) SetResponse(t *testing.T, path string, response []byte) {
t.Helper()
m.mu.Lock()
defer m.mu.Unlock()
path = strings.TrimPrefix(path, "/")
m.response[path] = response
}

// SetResponseFromFile sets the Server's response for the URL path to be the contents of the file at filename.
func (m *MockHTTPServer) SetResponseFromFile(t *testing.T, path string, filename string) {
t.Helper()
b, err := os.ReadFile(filename)
if err != nil {
t.Fatalf("failed to read response file: %v", err)
}
m.SetResponse(t, path, b)
}

// SetAuthorization sets the contents of the 'Authorization' header the server expects for all endpoints.
//
// The incoming requests' headers must match the auth string exactly, otherwise the server will response with 401 Unauthorized.
// If authorization is unset or empty, the server will not require authorization.
func (m *MockHTTPServer) SetAuthorization(t *testing.T, auth string) {
t.Helper()
m.mu.Lock()
defer m.mu.Unlock()
m.authorization = auth
}

// ServeHTTP is the http.Handler for the underlying httptest.Server.
func (m *MockHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
path := r.URL.RawPath
if path == "" {
path = r.URL.Path
}
m.mu.Lock()
wantAuth := m.authorization
resp, ok := m.response[strings.TrimPrefix(path, "/")]
m.mu.Unlock()

if wantAuth != "" && r.Header.Get("Authorization") != wantAuth {
w.WriteHeader(http.StatusUnauthorized)
resp = []byte("unauthorized")
} else if !ok {
w.WriteHeader(http.StatusNotFound)
resp = []byte("not found")
}

if _, err := w.Write(resp); err != nil {
log.Fatalf("Write: %v", err)
}
}

0 comments on commit 1c01916

Please sign in to comment.