@@ -59,13 +59,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
5959 return err
6060 }
6161
62- if ! headerValuesContainsToken (r .Header , "Connection" , "Upgrade" ) {
62+ if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
6363 err := fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
6464 http .Error (w , err .Error (), http .StatusBadRequest )
6565 return err
6666 }
6767
68- if ! headerValuesContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
68+ if ! headerContainsToken (r .Header , "Upgrade" , "WebSocket" ) {
6969 err := fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
7070 http .Error (w , err .Error (), http .StatusBadRequest )
7171 return err
@@ -144,6 +144,18 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
144144 w .Header ().Set ("Sec-WebSocket-Protocol" , subproto )
145145 }
146146
147+ var copts * CompressionOptions
148+ if opts .Compression != nil {
149+ copts , err = negotiateCompression (r .Header , opts .Compression )
150+ if err != nil {
151+ http .Error (w , err .Error (), http .StatusBadRequest )
152+ return nil , err
153+ }
154+ if copts != nil {
155+ copts .setHeader (w .Header ())
156+ }
157+ }
158+
147159 w .WriteHeader (http .StatusSwitchingProtocols )
148160
149161 netConn , brw , err := hj .Hijack ()
@@ -162,40 +174,65 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
162174 br : brw .Reader ,
163175 bw : brw .Writer ,
164176 closer : netConn ,
177+ copts : copts ,
165178 }
166179 c .init ()
167180
168181 return c , nil
169182}
170183
171- func headerValuesContainsToken (h http.Header , key , token string ) bool {
184+ func headerContainsToken (h http.Header , key , token string ) bool {
172185 key = textproto .CanonicalMIMEHeaderKey (key )
173186
174- for _ , val2 := range h [key ] {
175- if headerValueContainsToken (val2 , token ) {
187+ token = strings .ToLower (token )
188+ match := func (t string ) bool {
189+ return t == token
190+ }
191+
192+ for _ , v := range h [key ] {
193+ if searchHeaderTokens (v , match ) != "" {
176194 return true
177195 }
178196 }
179197
180198 return false
181199}
182200
183- func headerValueContainsToken ( val2 , token string ) bool {
184- val2 = strings . TrimSpace ( val2 )
201+ func headerTokenHasPrefix ( h http. Header , key , prefix string ) string {
202+ key = textproto . CanonicalMIMEHeaderKey ( key )
185203
186- for _ , val2 := range strings .Split (val2 , "," ) {
187- val2 = strings .TrimSpace (val2 )
188- if strings .EqualFold (val2 , token ) {
189- return true
204+ prefix = strings .ToLower (prefix )
205+ match := func (t string ) bool {
206+ return strings .HasPrefix (t , prefix )
207+ }
208+
209+ for _ , v := range h [key ] {
210+ found := searchHeaderTokens (v , match )
211+ if found != "" {
212+ return found
190213 }
191214 }
192215
193- return false
216+ return ""
217+ }
218+
219+ func searchHeaderTokens (v string , match func (val string ) bool ) string {
220+ v = strings .TrimSpace (v )
221+
222+ for _ , v2 := range strings .Split (v , "," ) {
223+ v2 = strings .TrimSpace (v2 )
224+ v2 = strings .ToLower (v2 )
225+ if match (v2 ) {
226+ return v2
227+ }
228+ }
229+
230+ return ""
194231}
195232
196233func selectSubprotocol (r * http.Request , subprotocols []string ) string {
197234 for _ , sp := range subprotocols {
198- if headerValuesContainsToken (r .Header , "Sec-WebSocket-Protocol" , sp ) {
235+ if headerContainsToken (r .Header , "Sec-WebSocket-Protocol" , sp ) {
199236 return sp
200237 }
201238 }
@@ -268,36 +305,32 @@ type DialOptions struct {
268305//
269306// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression.
270307//
271- // Enabling compression will increase memory and CPU usage.
272- // Thus it is not ideal for every use case and disabled by default .
308+ // Enabling compression will increase memory and CPU usage and should
309+ // be profiled before enabling in production .
273310// See https://github.com/gorilla/websocket/issues/203
274- // Profile before enabling in production.
275311//
276312// This API is experimental and subject to change.
277313type CompressionOptions struct {
278- // ServerNoContextTakeover controls whether the server should use context takeover.
279- // See docs on CompressionOptions for discussion regarding context takeover.
280- //
281- // If set by the client, will guarantee that the server does not use context takeover.
282- ServerNoContextTakeover bool
283-
284314 // ClientNoContextTakeover controls whether the client should use context takeover.
285315 // See docs on CompressionOptions for discussion regarding context takeover.
286316 //
287317 // If set by the server, will guarantee that the client does not use context takeover.
288318 ClientNoContextTakeover bool
289319
320+ // ServerNoContextTakeover controls whether the server should use context takeover.
321+ // See docs on CompressionOptions for discussion regarding context takeover.
322+ //
323+ // If set by the client, will guarantee that the server does not use context takeover.
324+ ServerNoContextTakeover bool
325+
290326 // Level controls the compression level used.
291327 // Defaults to flate.BestSpeed.
292328 Level int
293329
294330 // Threshold controls the minimum message size in bytes before compression is used.
295- // In the case of ContextTakeover == false, a flate.Writer will not be grabbed
296- // from the pool until the message exceeds this threshold.
297- //
298331 // Must not be greater than 4096 as that is the write buffer's size.
299332 //
300- // Defaults to 512 .
333+ // Defaults to 256 .
301334 Threshold int
302335}
303336
@@ -319,25 +352,32 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon
319352 return c , r , nil
320353}
321354
322- func dial ( ctx context. Context , u string , opts * DialOptions ) ( _ * Conn , _ * http. Response , err error ) {
355+ func ( opts * DialOptions ) ensure () ( * DialOptions , error ) {
323356 if opts == nil {
324357 opts = & DialOptions {}
358+ } else {
359+ opts = & * opts
325360 }
326361
327- // Shallow copy to ensure defaults do not affect user passed options.
328- opts2 := * opts
329- opts = & opts2
330-
331362 if opts .HTTPClient == nil {
332363 opts .HTTPClient = http .DefaultClient
333364 }
334365 if opts .HTTPClient .Timeout > 0 {
335- return nil , nil , fmt .Errorf ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
366+ return nil , fmt .Errorf ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
336367 }
337368 if opts .HTTPHeader == nil {
338369 opts .HTTPHeader = http.Header {}
339370 }
340371
372+ return opts , nil
373+ }
374+
375+ func dial (ctx context.Context , u string , opts * DialOptions ) (_ * Conn , _ * http.Response , err error ) {
376+ opts , err = opts .ensure ()
377+ if err != nil {
378+ return nil , nil , err
379+ }
380+
341381 parsedURL , err := url .Parse (u )
342382 if err != nil {
343383 return nil , nil , fmt .Errorf ("failed to parse url: %w" , err )
@@ -367,7 +407,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
367407 req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
368408 }
369409 if opts .Compression != nil {
370- req . Header . Set ( "Sec-WebSocket-Extensions" , "permessage-deflate; server_no_context_takeover; client_no_context_takeover" )
410+ opts . Compression . setHeader ( req . Header )
371411 }
372412
373413 resp , err := opts .HTTPClient .Do (req )
@@ -384,7 +424,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
384424 }
385425 }()
386426
387- err = verifyServerResponse (req , resp )
427+ copts , err : = verifyServerResponse (req , resp , opts )
388428 if err != nil {
389429 return nil , resp , err
390430 }
@@ -400,38 +440,48 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re
400440 bw : getBufioWriter (rwc ),
401441 closer : rwc ,
402442 client : true ,
443+ copts : copts ,
403444 }
404445 c .extractBufioWriterBuf (rwc )
405446 c .init ()
406447
407448 return c , resp , nil
408449}
409450
410- func verifyServerResponse (r * http.Request , resp * http.Response ) error {
451+ func verifyServerResponse (r * http.Request , resp * http.Response , opts * DialOptions ) ( * CompressionOptions , error ) {
411452 if resp .StatusCode != http .StatusSwitchingProtocols {
412- return fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
453+ return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
413454 }
414455
415- if ! headerValuesContainsToken (resp .Header , "Connection" , "Upgrade" ) {
416- return fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
456+ if ! headerContainsToken (resp .Header , "Connection" , "Upgrade" ) {
457+ return nil , fmt .Errorf ("websocket protocol violation: Connection header %q does not contain Upgrade" , resp .Header .Get ("Connection" ))
417458 }
418459
419- if ! headerValuesContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
420- return fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
460+ if ! headerContainsToken (resp .Header , "Upgrade" , "WebSocket" ) {
461+ return nil , fmt .Errorf ("websocket protocol violation: Upgrade header %q does not contain websocket" , resp .Header .Get ("Upgrade" ))
421462 }
422463
423464 if resp .Header .Get ("Sec-WebSocket-Accept" ) != secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )) {
424- return fmt .Errorf ("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
465+ return nil , fmt .Errorf ("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q" ,
425466 resp .Header .Get ("Sec-WebSocket-Accept" ),
426467 r .Header .Get ("Sec-WebSocket-Key" ),
427468 )
428469 }
429470
430- if proto := resp .Header .Get ("Sec-WebSocket-Protocol" ); proto != "" && ! headerValuesContainsToken (r .Header , "Sec-WebSocket-Protocol" , proto ) {
431- return fmt .Errorf ("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
471+ if proto := resp .Header .Get ("Sec-WebSocket-Protocol" ); proto != "" && ! headerContainsToken (r .Header , "Sec-WebSocket-Protocol" , proto ) {
472+ return nil , fmt .Errorf ("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
432473 }
433474
434- return nil
475+ var copts * CompressionOptions
476+ if opts .Compression != nil {
477+ var err error
478+ copts , err = negotiateCompression (resp .Header , opts .Compression )
479+ if err != nil {
480+ return nil , err
481+ }
482+ }
483+
484+ return copts , nil
435485}
436486
437487// The below pools can only be used by the client because http.Hijacker will always
@@ -477,3 +527,55 @@ func makeSecWebSocketKey() (string, error) {
477527 }
478528 return base64 .StdEncoding .EncodeToString (b ), nil
479529}
530+
531+ func negotiateCompression (h http.Header , copts * CompressionOptions ) (* CompressionOptions , error ) {
532+ deflate := headerTokenHasPrefix (h , "Sec-WebSocket-Extensions" , "permessage-deflate" )
533+ if deflate == "" {
534+ return nil , nil
535+ }
536+
537+ // Ensures our changes do not modify the real compression options.
538+ copts = & * copts
539+
540+ params := strings .Split (deflate , ";" )
541+ for i := range params {
542+ params [i ] = strings .TrimSpace (params [i ])
543+ }
544+
545+ if params [0 ] != "permessage-deflate" {
546+ return nil , fmt .Errorf ("unexpected header format for permessage-deflate extension: %q" , deflate )
547+ }
548+
549+ for _ , p := range params [1 :] {
550+ switch p {
551+ case "client_no_context_takeover" :
552+ copts .ClientNoContextTakeover = true
553+ continue
554+ case "server_no_context_takeover" :
555+ copts .ServerNoContextTakeover = true
556+ continue
557+ case "client_max_window_bits" , "server-max-window-bits" :
558+ server := h .Get ("Sec-WebSocket-Key" ) != ""
559+ if server {
560+ // If we are the server, we are allowed to ignore these parameters.
561+ // However, if we are the client, we must obey them but because of
562+ // https://github.com/golang/go/issues/3155 we cannot.
563+ continue
564+ }
565+ }
566+ return nil , fmt .Errorf ("unsupported permessage-deflate parameter %q in header: %q" , p , deflate )
567+ }
568+
569+ return copts , nil
570+ }
571+
572+ func (copts * CompressionOptions ) setHeader (h http.Header ) {
573+ s := "permessage-deflate"
574+ if copts .ClientNoContextTakeover {
575+ s += "; client_no_context_takeover"
576+ }
577+ if copts .ServerNoContextTakeover {
578+ s += "; server_no_context_takeover"
579+ }
580+ h .Set ("Sec-WebSocket-Extensions" , s )
581+ }
0 commit comments