From 526c39c1df8c767c06e2e5189027a4644b53c627 Mon Sep 17 00:00:00 2001 From: Johnny Steenbergen Date: Sat, 13 Jan 2024 18:35:12 -0600 Subject: [PATCH] chore(allsrv): replace `http.DefaultServeMux` with isolated `*http.ServeMux` dependency This resolves the panic adding routes with the same pattern multiple times. Now each `Server`, has its own `*http.ServeMux`. Now tests run independent of one another and we avoid the pain of GLOBALS! The tests should now pass :-) --- allsrv/server.go | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/allsrv/server.go b/allsrv/server.go index 9797923..cf3f92f 100644 --- a/allsrv/server.go +++ b/allsrv/server.go @@ -5,7 +5,7 @@ import ( "errors" "log" "net/http" - + "github.com/gofrs/uuid" ) @@ -18,7 +18,7 @@ import ( a) what happens if we forget that copy pasta? 3) auth is hardcoded to basic auth a) what happens if we want to adapt some other means of auth? - 4) router being used is the GLOBAL http.DefaultServeMux + ✅4) router being used is the GLOBAL http.DefaultServeMux a) should avoid globals b) what happens if you have multiple servers in this go module who reference default serve mux? 5) no tests @@ -45,10 +45,11 @@ import ( */ type Server struct { - db *InmemDB // 1) - + db *InmemDB // 1) + mux *http.ServeMux // 4) + user, pass string // 3) - + idFn func() string // 11) } @@ -62,6 +63,7 @@ func WithIDFn(fn func() string) func(*Server) { func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server { s := Server{ db: db, + mux: http.NewServeMux(), // 4) user: user, pass: pass, idFn: func() string { @@ -72,22 +74,22 @@ func NewServer(db *InmemDB, user, pass string, opts ...func(*Server)) *Server { for _, o := range opts { o(&s) } - + s.routes() return &s } func (s *Server) routes() { // 4) 7) 9) 10) - http.Handle("POST /foo", http.HandlerFunc(s.createFoo)) - http.Handle("GET /foo", http.HandlerFunc(s.readFoo)) - http.Handle("PUT /foo", http.HandlerFunc(s.updateFoo)) - http.Handle("DELETE /foo", http.HandlerFunc(s.delFoo)) + s.mux.Handle("POST /foo", http.HandlerFunc(s.createFoo)) + s.mux.Handle("GET /foo", http.HandlerFunc(s.readFoo)) + s.mux.Handle("PUT /foo", http.HandlerFunc(s.updateFoo)) + s.mux.Handle("DELETE /foo", http.HandlerFunc(s.delFoo)) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 4) - http.DefaultServeMux.ServeHTTP(w, r) + s.mux.ServeHTTP(w, r) } type Foo struct { @@ -103,20 +105,20 @@ func (s *Server) createFoo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) // 9) return } - + var f Foo if err := json.NewDecoder(r.Body).Decode(&f); err != nil { w.WriteHeader(http.StatusForbidden) // 9) return } - + f.ID = s.idFn() // 11) - + if err := s.db.CreateFoo(f); err != nil { w.WriteHeader(http.StatusInternalServerError) // 9) return } - + w.WriteHeader(http.StatusCreated) if err := json.NewEncoder(w).Encode(f); err != nil { log.Printf("unexpected error writing json value to response body: " + err.Error()) // 8) 10) @@ -129,13 +131,13 @@ func (s *Server) readFoo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) // 9) return } - + f, err := s.db.readFoo(r.URL.Query().Get("id")) if err != nil { w.WriteHeader(http.StatusNotFound) // 9) return } - + if err := json.NewEncoder(w).Encode(f); err != nil { log.Printf("unexpected error writing json value to response body: " + err.Error()) // 8) 10) } @@ -147,13 +149,13 @@ func (s *Server) updateFoo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) // 9) return } - + var f Foo if err := json.NewDecoder(r.Body).Decode(&f); err != nil { w.WriteHeader(http.StatusForbidden) // 9) return } - + if err := s.db.updateFoo(f); err != nil { w.WriteHeader(http.StatusInternalServerError) // 9) return @@ -166,7 +168,7 @@ func (s *Server) delFoo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) // 9) return } - + if err := s.db.delFoo(r.URL.Query().Get("id")); err != nil { w.WriteHeader(http.StatusNotFound) // 9) return @@ -184,9 +186,9 @@ func (db *InmemDB) CreateFoo(f Foo) error { return errors.New("foo " + f.Name + " exists") // 8) } } - + db.m = append(db.m, f) - + return nil }