44 "log"
55 "net/http"
66 "net/url"
7- "os"
87 "strings"
98 "sync"
109
@@ -15,6 +14,9 @@ type Hub struct {
1514 mu sync.RWMutex
1615 rooms map [string ]map [string ]* Client
1716 upg websocket.Upgrader
17+
18+ allowedOrigins []string
19+ allowAllOrigins bool
1820}
1921
2022type Client struct {
@@ -24,31 +26,31 @@ type Client struct {
2426 send chan Message
2527}
2628
27- var (
28- allowedOrigins []string
29- allowAllOrigins bool
30- )
29+ type Options struct {
30+ AllowedOrigins []string
31+ AllowAllOrigins bool
32+ }
3133
32- func init () {
33- v := os .Getenv ("WS_ALLOWED_ORIGINS" )
34- if v == "" {
35- return
36- }
37- if v == "*" {
38- allowAllOrigins = true
39- return
40- }
41- parts := strings .Split (v , "," )
42- for _ , p := range parts {
43- p = strings .TrimSpace (p )
44- if p != "" {
45- allowedOrigins = append (allowedOrigins , p )
46- }
34+ func NewHub () * Hub {
35+ return NewHubWithOptions (Options {})
36+ }
37+
38+ func NewHubWithOptions (opts Options ) * Hub {
39+ h := & Hub {
40+ rooms : make (map [string ]map [string ]* Client ),
41+ allowedOrigins : append ([]string (nil ), opts .AllowedOrigins ... ),
42+ allowAllOrigins : opts .AllowAllOrigins ,
43+ }
44+ h .upg = websocket.Upgrader {
45+ CheckOrigin : func (r * http.Request ) bool {
46+ return h .isOriginAllowed (r )
47+ },
4748 }
49+ return h
4850}
4951
50- func isOriginAllowed (r * http.Request ) bool {
51- if allowAllOrigins {
52+ func ( h * Hub ) isOriginAllowed (r * http.Request ) bool {
53+ if h . allowAllOrigins {
5254 return true
5355 }
5456 origin := r .Header .Get ("Origin" )
@@ -59,7 +61,7 @@ func isOriginAllowed(r *http.Request) bool {
5961 }
6062 return false
6163 }
62- if len (allowedOrigins ) == 0 {
64+ if len (h . allowedOrigins ) == 0 {
6365 u , err := url .Parse (origin )
6466 if err != nil {
6567 return false
@@ -70,36 +72,26 @@ func isOriginAllowed(r *http.Request) bool {
7072 }
7173 return false
7274 }
73- for _ , o := range allowedOrigins {
75+ for _ , o := range h . allowedOrigins {
7476 if o == origin {
7577 return true
7678 }
7779 }
7880 return false
7981}
8082
81- func NewHub () * Hub {
82- return & Hub {
83- rooms : make (map [string ]map [string ]* Client ),
84- upg : websocket.Upgrader {
85- CheckOrigin : func (r * http.Request ) bool {
86- return isOriginAllowed (r )
87- },
88- },
89- }
90- }
91-
9283func (h * Hub ) HandleWS (w http.ResponseWriter , r * http.Request ) {
9384 c , err := h .upg .Upgrade (w , r , nil )
9485 if err != nil {
95- log .Printf ("signal: ws upgrade failed from %s: %v" , r .RemoteAddr , err )
86+ log .Printf ("signal: ws upgrade failed from %s path=%s : %v" , r .RemoteAddr , r . URL . Path , err )
9687 return
9788 }
9889 log .Printf ("signal: ws connected from %s" , r .RemoteAddr )
9990 client := & Client {conn : c , send : make (chan Message , 32 )}
10091 go h .writePump (client )
10192 defer func () {
10293 h .removeClient (client )
94+ close (client .send )
10395 c .Close ()
10496 }()
10597 for {
@@ -162,30 +154,34 @@ func (h *Hub) removeClient(c *Client) {
162154 if c .room == "" || c .id == "" {
163155 return
164156 }
165- if m , ok := h .rooms [c .room ]; ok {
166- if existing , ok2 := m [c .id ]; ok2 {
167- delete (m , c .id )
168- close (existing .send )
157+ room := c .room
158+ if m , ok := h .rooms [room ]; ok {
159+ if _ , ok2 := m [c .id ]; ! ok2 {
160+ c .room = ""
161+ return
169162 }
163+ delete (m , c .id )
164+ c .room = ""
170165 if len (m ) == 0 {
171- delete (h .rooms , c .room )
172- log .Printf ("signal: room %s closed" , c .room )
173- } else {
174- members := make ([]string , 0 , len (m ))
175- for id := range m {
176- members = append (members , id )
177- }
178- msg := Message {
179- Type : "room_members" ,
180- Room : c .room ,
181- Members : members ,
182- }
183- for _ , cli := range m {
184- if cli != nil && cli .conn != nil {
185- select {
186- case cli .send <- msg :
187- default :
188- }
166+ delete (h .rooms , room )
167+ log .Printf ("signal: room %s closed" , room )
168+ return
169+ }
170+
171+ members := make ([]string , 0 , len (m ))
172+ for id := range m {
173+ members = append (members , id )
174+ }
175+ msg := Message {
176+ Type : "room_members" ,
177+ Room : room ,
178+ Members : members ,
179+ }
180+ for _ , cli := range m {
181+ if cli != nil && cli .conn != nil {
182+ select {
183+ case cli .send <- msg :
184+ default :
189185 }
190186 }
191187 }
@@ -210,6 +206,7 @@ func (h *Hub) writePump(c *Client) {
210206 for msg := range c .send {
211207 if err := c .conn .WriteJSON (msg ); err != nil {
212208 log .Printf ("signal: write message error room=%s id=%s: %v" , c .room , c .id , err )
209+ c .conn .Close ()
213210 break
214211 }
215212 }
0 commit comments