Skip to content

Commit

Permalink
chore(allsrv): inject authorization mechanism to decouple it from server
Browse files Browse the repository at this point in the history
This does a few things:

  1) Allows us to control the auth at setup without having to update
     the implementation of the server endpoints. We've effectively decoupled
     our auth from the server, which gives us freedom to adapt to future
     asks.
  2) The injection is using a `middleware` function. This could be
     an interface as well. Its totally up to the developer/team. Sometimes
     an interface is more useful.
  3) We have the freedom to ignore auth in tests if we so desire. This
     can be useful if your auth setup is non-trivial and involves a good
     bit of complexity.
  • Loading branch information
jsteenb2 committed Jul 5, 2024
1 parent 5030afa commit 1a259d6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
33 changes: 22 additions & 11 deletions allsrv/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
a) what happens if we want a different db?
✅2) auth is copy-pasted in each handler
a) what happens if we forget that copy pasta?
3) auth is hardcoded to basic auth
3) auth is hardcoded to basic auth
a) what happens if we want to adapt some other means of auth?
✅4) router being used is the GLOBAL http.DefaultServeMux
a) should avoid globals
Expand Down Expand Up @@ -49,9 +49,16 @@ type Server struct {
db *InmemDB // 1)
mux *http.ServeMux // 4)

user, pass string // 3)
authFn func(http.Handler) http.Handler // 3)
idFn func() string // 11)
}

idFn func() string // 11)
// WithBasicAuth sets the authorization fn for the server to basic auth.
// 3)
func WithBasicAuth(user, pass string) func(*Server) {
return func(s *Server) {
s.authFn = basicAuth(user, pass)
}
}

// WithIDFn sets the id generation fn for the server.
Expand All @@ -61,12 +68,16 @@ func WithIDFn(fn func() string) func(*Server) {
}
}

func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server {
func NewServer(db *InmemDB, opts ...func(*Server)) *Server {
s := Server{
db: db,
mux: http.NewServeMux(), // 4)
user: user,
pass: pass,
db: db,
mux: http.NewServeMux(), // 4)
authFn: func(next http.Handler) http.Handler { // 3)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// defaults to no auth
next.ServeHTTP(w, r)
})
},
idFn: func() string {
// defaults to using a uuid
return uuid.Must(uuid.NewV4()).String()
Expand All @@ -81,7 +92,7 @@ func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server {
}

func (s *Server) routes() {
authMW := BasicAuth(s.user, s.pass) // 2)
authMW := s.authFn // 2)

// 4) 7) 9) 10)
s.mux.Handle("POST /foo", authMW(http.HandlerFunc(s.createFoo)))
Expand Down Expand Up @@ -154,9 +165,9 @@ func (s *Server) delFoo(w http.ResponseWriter, r *http.Request) {
}
}

// BasicAuth provides a basic auth middleware to an http server.
// basicAuth provides a basic auth middleware to an http server.
// 2)
func BasicAuth(expectedUser, expectedPass string) func(http.Handler) http.Handler {
func basicAuth(expectedUser, expectedPass string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if user, pass, ok := r.BasicAuth(); !(ok && user == expectedUser && pass == expectedPass) {
Expand Down
23 changes: 13 additions & 10 deletions allsrv/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ func TestServer(t *testing.T) {
t.Run("foo create", func(t *testing.T) {
t.Run("when provided a valid foo should pass", func(t *testing.T) {
db := new(allsrv.InmemDB)
svr := allsrv.NewServer(db, "dodgers@stink.com", "PaSsWoRd", allsrv.WithIDFn(func() string {
return "id1"
}))
svr := allsrv.NewServer(db,
allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"),
allsrv.WithIDFn(func() string {
return "id1"
}),
)

req := httptest.NewRequest("POST", "/foo", newJSONBody(t, allsrv.Foo{
Name: "first-foo",
Expand All @@ -43,7 +46,7 @@ func TestServer(t *testing.T) {
})

t.Run("when provided invalid basic auth should fail", func(t *testing.T) {
svr := allsrv.NewServer(new(allsrv.InmemDB), "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(new(allsrv.InmemDB), allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("POST", "/foo", newJSONBody(t, allsrv.Foo{
Name: "first-foo",
Expand All @@ -68,7 +71,7 @@ func TestServer(t *testing.T) {
})
require.NoError(t, err)

svr := allsrv.NewServer(db, "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(db, allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("GET", "/foo?id=reader1", nil)
req.SetBasicAuth("dodgers@stink.com", "PaSsWoRd")
Expand All @@ -88,7 +91,7 @@ func TestServer(t *testing.T) {
})

t.Run("when provided invalid basic auth should fail", func(t *testing.T) {
svr := allsrv.NewServer(new(allsrv.InmemDB), "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(new(allsrv.InmemDB), allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("GET", "/foo?id=reader1", nil)
req.SetBasicAuth("dodgers@rule.com", "wrongO")
Expand All @@ -110,7 +113,7 @@ func TestServer(t *testing.T) {
})
require.NoError(t, err)

svr := allsrv.NewServer(db, "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(db, allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("PUT", "/foo", newJSONBody(t, allsrv.Foo{
ID: "id1",
Expand All @@ -135,7 +138,7 @@ func TestServer(t *testing.T) {
})
require.NoError(t, err)

svr := allsrv.NewServer(db, "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(db, allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("PUT", "/foo", newJSONBody(t, allsrv.Foo{
ID: "id1",
Expand All @@ -161,7 +164,7 @@ func TestServer(t *testing.T) {
})
require.NoError(t, err)

svr := allsrv.NewServer(db, "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(db, allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("DELETE", "/foo?id=id1", nil)
req.SetBasicAuth("dodgers@stink.com", "PaSsWoRd")
Expand All @@ -173,7 +176,7 @@ func TestServer(t *testing.T) {
})

t.Run("when provided invalid basic auth should fail", func(t *testing.T) {
svr := allsrv.NewServer(new(allsrv.InmemDB), "dodgers@stink.com", "PaSsWoRd")
svr := allsrv.NewServer(new(allsrv.InmemDB), allsrv.WithBasicAuth("dodgers@stink.com", "PaSsWoRd"))

req := httptest.NewRequest("DELETE", "/foo?id=id1", nil)
req.SetBasicAuth("dodgers@rule.com", "wrongO")
Expand Down

0 comments on commit 1a259d6

Please sign in to comment.