@@ -9,10 +9,11 @@ import (
99 "errors"
1010 "fmt"
1111 "io"
12+ "log"
1213 "net/http"
1314 "net/textproto"
1415 "net/url"
15- "strconv "
16+ "path/filepath "
1617 "strings"
1718
1819 "nhooyr.io/websocket/internal/errd"
@@ -25,18 +26,27 @@ type AcceptOptions struct {
2526 // reject it, close the connection when c.Subprotocol() == "".
2627 Subprotocols []string
2728
28- // InsecureSkipVerify disables Accept's origin verification behaviour. By default,
29- // the connection will only be accepted if the request origin is equal to the request
30- // host.
29+ // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
3130 //
32- // This is only required if you want javascript served from a different domain
33- // to access your WebSocket server.
31+ // Deprecated: Use OriginPatterns with a match all pattern of * instead to control
32+ // origin authorization yourself.
33+ InsecureSkipVerify bool
34+
35+ // OriginPatterns lists the host patterns for authorized origins.
36+ // The request host is always authorized.
37+ // Use this to enable cross origin WebSockets.
38+ //
39+ // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
40+ // In such a case, example.com is the origin and chat.example.com is the request host.
41+ // One would set this field to []string{"example.com"} to authorize example.com to connect.
3442 //
35- // See https://stackoverflow.com/a/37837709/4283659
43+ // Each pattern is matched case insensitively against the request origin host
44+ // with filepath.Match.
45+ // See https://golang.org/pkg/path/filepath/#Match
3646 //
3747 // Please ensure you understand the ramifications of enabling this.
3848 // If used incorrectly your WebSocket server will be open to CSRF attacks.
39- InsecureSkipVerify bool
49+ OriginPatterns [] string
4050
4151 // CompressionMode controls the compression mode.
4252 // Defaults to CompressionNoContextTakeover.
@@ -77,8 +87,12 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
7787 }
7888
7989 if ! opts .InsecureSkipVerify {
80- err = authenticateOrigin (r )
90+ err = authenticateOrigin (r , opts . OriginPatterns )
8191 if err != nil {
92+ if errors .Is (err , filepath .ErrBadPattern ) {
93+ log .Printf ("websocket: %v" , err )
94+ err = errors .New (http .StatusText (http .StatusForbidden ))
95+ }
8296 http .Error (w , err .Error (), http .StatusForbidden )
8397 return nil , err
8498 }
@@ -165,18 +179,35 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
165179 return 0 , nil
166180}
167181
168- func authenticateOrigin (r * http.Request ) error {
182+ func authenticateOrigin (r * http.Request , originHosts [] string ) error {
169183 origin := r .Header .Get ("Origin" )
170- if origin != "" {
171- u , err := url .Parse (origin )
184+ if origin == "" {
185+ return nil
186+ }
187+
188+ u , err := url .Parse (origin )
189+ if err != nil {
190+ return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
191+ }
192+
193+ if strings .EqualFold (r .Host , u .Host ) {
194+ return nil
195+ }
196+
197+ for _ , hostPattern := range originHosts {
198+ matched , err := match (hostPattern , u .Host )
172199 if err != nil {
173- return fmt .Errorf ("failed to parse Origin header %q: %w" , origin , err )
200+ return fmt .Errorf ("failed to parse filepath pattern %q: %w" , hostPattern , err )
174201 }
175- if ! strings . EqualFold ( u . Host , r . Host ) {
176- return fmt . Errorf ( "request Origin %q is not authorized for Host %q" , origin , r . Host )
202+ if matched {
203+ return nil
177204 }
178205 }
179- return nil
206+ return fmt .Errorf ("request Origin %q is not authorized for Host %q" , origin , r .Host )
207+ }
208+
209+ func match (pattern , s string ) (bool , error ) {
210+ return filepath .Match (strings .ToLower (pattern ), strings .ToLower (s ))
180211}
181212
182213func selectSubprotocol (r * http.Request , subprotocols []string ) string {
@@ -235,16 +266,6 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
235266 return copts , nil
236267}
237268
238- // parseExtensionParameter parses the value in the extension parameter p.
239- func parseExtensionParameter (p string ) (int , bool ) {
240- ps := strings .Split (p , "=" )
241- if len (ps ) == 1 {
242- return 0 , false
243- }
244- i , e := strconv .Atoi (strings .Trim (ps [1 ], `"` ))
245- return i , e == nil
246- }
247-
248269func acceptWebkitDeflate (w http.ResponseWriter , ext websocketExtension , mode CompressionMode ) (* compressionOptions , error ) {
249270 copts := mode .opts ()
250271 // The peer must explicitly request it.
0 commit comments