Skip to content

Commit

Permalink
simplify echo server
Browse files Browse the repository at this point in the history
  • Loading branch information
slytomcat committed Apr 20, 2024
1 parent a89955e commit d7df728
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 30 deletions.
3 changes: 1 addition & 2 deletions echo-server/echo-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ func main() {
case "broadcast":
count := 0
out, _ := json.Marshal(in)
srv.ForEachConnection(func(c *websocket.Conn) bool {
srv.ForEachConnection(func(c *websocket.Conn) {
c.WriteMessage(websocket.TextMessage, out)
count++
return true
})
sendMsg(msg{"broadcastResult", in.Payload, &count}, conn)
default:
Expand Down
37 changes: 19 additions & 18 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"fmt"
"net/http"
"strings"
Expand All @@ -10,14 +11,6 @@ import (
"github.com/gorilla/websocket"
)

var upgrader = websocket.Upgrader{
HandshakeTimeout: time.Second,
Subprotocols: []string{},
CheckOrigin: func(r *http.Request) bool {
return true
},
}

// Server is a websocket/http server. It is wrapper for standard http.Server with additional functionality for websocket request handling.
type Server struct {
http.Server
Expand All @@ -34,8 +27,14 @@ type Server struct {
func NewServer(addr string) *Server {
mux := http.NewServeMux()
return &Server{
mux: mux,
Upgrader: upgrader,
mux: mux,
Upgrader: websocket.Upgrader{
HandshakeTimeout: time.Second,
Subprotocols: []string{},
CheckOrigin: func(r *http.Request) bool {
return true
},
},
Server: http.Server{
Addr: addr,
Handler: mux,
Expand All @@ -53,16 +52,17 @@ func (s *Server) HandleFunc(path string, handler func(w http.ResponseWriter, r *
s.mux.HandleFunc(path, handler)
}

// Close correctly closes all active ws connections and close the server
// Close correctly closes all active ws connections and shutdown the server
func (s *Server) Close() error {
s.ForEachConnection(func(c *websocket.Conn) bool {
s.ForEachConnection(func(c *websocket.Conn) {
if err := TryCloseNormally(c, "server going down"); err != nil {
fmt.Printf("server: closing connection from %s error: %v\n", c.RemoteAddr(), err)
}
c.Close()
return true
})
return s.Server.Close()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
return s.Server.Shutdown(ctx)
}

// TryCloseNormally tries to close websocket connection normally i.e. according to RFC
Expand All @@ -80,7 +80,7 @@ func TryCloseNormally(conn *websocket.Conn, message string) error {

func (s *Server) serve(handler func(*websocket.Conn)) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
connection, err := upgrader.Upgrade(w, r, nil)
connection, err := s.Upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Printf("server: connection upgrade error: %v\n", err)
return
Expand All @@ -96,8 +96,9 @@ func (s *Server) serve(handler func(*websocket.Conn)) func(w http.ResponseWriter
}

// ForEachConnection allow to iterate over all active connections
func (s *Server) ForEachConnection(f func(*websocket.Conn) bool) {
s.connections.Range(func(key, _ any) bool {
return f(key.(*websocket.Conn))
func (s *Server) ForEachConnection(f func(*websocket.Conn)) {
s.connections.Range(func(conn, _ any) bool {
f(conn.(*websocket.Conn))
return true
})
}
63 changes: 53 additions & 10 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package server
import (
"context"
"net/http"
"sync"
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -50,15 +52,11 @@ func TestHandshakeServerError(t *testing.T) {
fullURL := "ws://:8080"
s := NewServer(":8080")
require.NotNil(t, s)
s.Upgrader.CheckOrigin = func(r *http.Request) bool { return false }
s.WSHandleFunc("/", EchoHandler)
go func() { s.ListenAndServe() }()
time.Sleep(50 * time.Millisecond)
defer s.Close()
orig := s.Upgrader
upgrader.CheckOrigin = func(r *http.Request) bool { return false }
defer func() {
s.Upgrader = orig
}()
dialer := websocket.Dialer{}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -70,15 +68,11 @@ func TestHandshakeClientError(t *testing.T) {
fullURL := "ws://:8080"
s := NewServer(":8080")
require.NotNil(t, s)
s.Upgrader.CheckOrigin = func(r *http.Request) bool { return false }
s.WSHandleFunc("/", EchoHandler)
go func() { s.ListenAndServe() }()
time.Sleep(50 * time.Millisecond)
defer s.Close()
orig := s.Upgrader
upgrader.CheckOrigin = func(r *http.Request) bool { return false }
defer func() {
s.Upgrader = orig
}()
dialer := websocket.Dialer{HandshakeTimeout: time.Nanosecond}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand All @@ -101,3 +95,52 @@ func TestRegularHandler(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode)

}

func TestForEachConnection(t *testing.T) {
fullURL := "ws://localhost:8080"
s := NewServer("localhost:8080")
require.NotNil(t, s)
s.WSHandleFunc("/", EchoHandler)
go func() { s.ListenAndServe() }()
time.Sleep(50 * time.Millisecond)
defer s.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sentMsg := []byte("ping")
readers := 7
readCh := make(chan []byte, readers)
wg := sync.WaitGroup{}
wg.Add(readers)
for i := 0; i < readers; i++ {
go func() {
defer wg.Done()
dialer := websocket.Dialer{}
conn, _, err := dialer.DialContext(ctx, fullURL, nil)
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
require.NoError(t, err)
readCh <- msg
}()
}
time.Sleep(50 * time.Millisecond)
s.ForEachConnection(func(c *websocket.Conn) {
err := c.WriteMessage(websocket.TextMessage, sentMsg)
assert.NoError(t, err)
})
wg.Wait()
close(readCh)
cnt := 0
for msg := range readCh {
assert.Equal(t, sentMsg, msg)
cnt++
}
require.Equal(t, readers, cnt)
err := s.Close()
require.NoError(t, err)
cnt = 0
s.connections.Range(func(_, _ any) bool {
cnt++
return true
})
require.Zero(t, cnt)
}

0 comments on commit d7df728

Please sign in to comment.