diff --git a/handlers.go b/handlers.go index d322043..e73efc1 100644 --- a/handlers.go +++ b/handlers.go @@ -54,6 +54,7 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { defer s.clientsMu.Unlock() s.clients[conn] = struct{}{} ticker := time.NewTicker(pingPeriod) + stop := make(chan struct{}) // NIP-42 challenge challenge := make([]byte, 8) @@ -75,6 +76,8 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { go func() { defer func() { ticker.Stop() + stop <- struct{}{} + close(stop) s.clientsMu.Lock() if _, ok := s.clients[conn]; ok { conn.Close() @@ -113,7 +116,7 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { if ws.limiter != nil { // NOTE: Wait will throttle the requests. // To reject requests exceeding the limit, use if !ws.limiter.Allow() - if err := ws.limiter.Wait(context.Background()); err != nil { + if err := ws.limiter.Wait(context.TODO()); err != nil { s.Log.Warningf("unexpected limiter error %v", err) continue } @@ -125,7 +128,8 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } go func(message []byte) { - ctx = context.Background() + ctx = context.TODO() + var notice string defer func() { if notice != "" { @@ -377,11 +381,14 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) { for { select { case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(writeWait)) err := ws.WriteMessage(websocket.PingMessage, nil) if err != nil { s.Log.Errorf("error writing ping: %v; closing websocket", err) return } + case <-stop: + return } } }()