diff --git a/allsrv/server.go b/allsrv/server.go index fb92041..947fbbf 100644 --- a/allsrv/server.go +++ b/allsrv/server.go @@ -14,7 +14,7 @@ import ( Concerns: 1) the server depends on a hard type, coupling to the exact inmem db a) what happens if we want a different db? - 2) auth is copy-pasted in each handler + ✅2) auth is copy-pasted in each handler a) what happens if we forget that copy pasta? 3) auth is hardcoded to basic auth a) what happens if we want to adapt some other means of auth? @@ -81,11 +81,13 @@ func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server { } func (s *Server) routes() { + authMW := BasicAuth(s.user, s.pass) // 2) + // 4) 7) 9) 10) - s.mux.Handle("POST /foo", http.HandlerFunc(s.createFoo)) - s.mux.Handle("GET /foo", http.HandlerFunc(s.readFoo)) - s.mux.Handle("PUT /foo", http.HandlerFunc(s.updateFoo)) - s.mux.Handle("DELETE /foo", http.HandlerFunc(s.delFoo)) + s.mux.Handle("POST /foo", authMW(http.HandlerFunc(s.createFoo))) + s.mux.Handle("GET /foo", authMW(http.HandlerFunc(s.readFoo))) + s.mux.Handle("PUT /foo", authMW(http.HandlerFunc(s.updateFoo))) + s.mux.Handle("DELETE /foo", authMW(http.HandlerFunc(s.delFoo))) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -101,12 +103,6 @@ type Foo struct { } func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { - // 2) - if user, pass, ok := r.BasicAuth(); !(ok && user == s.user && pass == s.pass) { - w.WriteHeader(http.StatusUnauthorized) // 9) - return - } - var f Foo if err := json.NewDecoder(r.Body).Decode(&f); err != nil { w.WriteHeader(http.StatusForbidden) // 9) @@ -127,12 +123,6 @@ func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { } func (s *Server) readFoo(w http.ResponseWriter, r *http.Request) { - // 2) - if user, pass, ok := r.BasicAuth(); !(ok && user == s.user && pass == s.pass) { - w.WriteHeader(http.StatusUnauthorized) // 9) - return - } - f, err := s.db.readFoo(r.URL.Query().Get("id")) if err != nil { w.WriteHeader(http.StatusNotFound) // 9) @@ -145,12 +135,6 @@ func (s *Server) readFoo(w http.ResponseWriter, r *http.Request) { } func (s *Server) updateFoo(w http.ResponseWriter, r *http.Request) { - // 2) - if user, pass, ok := r.BasicAuth(); !(ok && user == s.user && pass == s.pass) { - w.WriteHeader(http.StatusUnauthorized) // 9) - return - } - var f Foo if err := json.NewDecoder(r.Body).Decode(&f); err != nil { w.WriteHeader(http.StatusForbidden) // 9) @@ -164,18 +148,26 @@ func (s *Server) updateFoo(w http.ResponseWriter, r *http.Request) { } func (s *Server) delFoo(w http.ResponseWriter, r *http.Request) { - // 2) - if user, pass, ok := r.BasicAuth(); !(ok && user == s.user && pass == s.pass) { - w.WriteHeader(http.StatusUnauthorized) // 9) - return - } - if err := s.db.delFoo(r.URL.Query().Get("id")); err != nil { w.WriteHeader(http.StatusNotFound) // 9) return } } +// BasicAuth provides a basic auth middleware to an http server. +// 2) +func BasicAuth(expectedUser, expectedPass string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if user, pass, ok := r.BasicAuth(); !(ok && user == expectedUser && pass == expectedPass) { + w.WriteHeader(http.StatusUnauthorized) // 9) + return + } + next.ServeHTTP(w, r) + }) + } +} + // InmemDB is an in-memory store. type InmemDB struct { m []Foo // 12)