@@ -82,7 +82,12 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
8282 return nil , nil , fmt .Errorf ("failed to generate Sec-WebSocket-Key: %w" , err )
8383 }
8484
85- resp , err := handshakeRequest (ctx , urls , opts , secWebSocketKey )
85+ var copts * compressionOptions
86+ if opts .CompressionMode != CompressionDisabled {
87+ copts = opts .CompressionMode .opts ()
88+ }
89+
90+ resp , err := handshakeRequest (ctx , urls , opts , copts , secWebSocketKey )
8691 if err != nil {
8792 return nil , resp , err
8893 }
@@ -104,7 +109,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
104109 }
105110 }()
106111
107- copts , err : = verifyServerResponse (opts , secWebSocketKey , resp )
112+ copts , err = verifyServerResponse (opts , copts , secWebSocketKey , resp )
108113 if err != nil {
109114 return nil , resp , err
110115 }
@@ -125,7 +130,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
125130 }), resp , nil
126131}
127132
128- func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , secWebSocketKey string ) (* http.Response , error ) {
133+ func handshakeRequest (ctx context.Context , urls string , opts * DialOptions , copts * compressionOptions , secWebSocketKey string ) (* http.Response , error ) {
129134 if opts .HTTPClient .Timeout > 0 {
130135 return nil , errors .New ("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67" )
131136 }
@@ -153,9 +158,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, secWe
153158 if len (opts .Subprotocols ) > 0 {
154159 req .Header .Set ("Sec-WebSocket-Protocol" , strings .Join (opts .Subprotocols , "," ))
155160 }
156- if opts .CompressionMode != CompressionDisabled {
157- copts := opts .CompressionMode .opts ()
158- copts .clientMaxWindowBits = 8
161+ if copts != nil {
159162 copts .setHeader (req .Header )
160163 }
161164
@@ -178,7 +181,7 @@ func secWebSocketKey(rr io.Reader) (string, error) {
178181 return base64 .StdEncoding .EncodeToString (b ), nil
179182}
180183
181- func verifyServerResponse (opts * DialOptions , secWebSocketKey string , resp * http.Response ) (* compressionOptions , error ) {
184+ func verifyServerResponse (opts * DialOptions , copts * compressionOptions , secWebSocketKey string , resp * http.Response ) (* compressionOptions , error ) {
182185 if resp .StatusCode != http .StatusSwitchingProtocols {
183186 return nil , fmt .Errorf ("expected handshake response status code %v but got %v" , http .StatusSwitchingProtocols , resp .StatusCode )
184187 }
@@ -203,7 +206,7 @@ func verifyServerResponse(opts *DialOptions, secWebSocketKey string, resp *http.
203206 return nil , err
204207 }
205208
206- return verifyServerExtensions (resp .Header )
209+ return verifyServerExtensions (copts , resp .Header )
207210}
208211
209212func verifySubprotocol (subprotos []string , resp * http.Response ) error {
@@ -221,19 +224,19 @@ func verifySubprotocol(subprotos []string, resp *http.Response) error {
221224 return fmt .Errorf ("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q" , proto )
222225}
223226
224- func verifyServerExtensions (h http.Header ) (* compressionOptions , error ) {
227+ func verifyServerExtensions (copts * compressionOptions , h http.Header ) (* compressionOptions , error ) {
225228 exts := websocketExtensions (h )
226229 if len (exts ) == 0 {
227230 return nil , nil
228231 }
229232
230233 ext := exts [0 ]
231- if ext .name != "permessage-deflate" || len (exts ) > 1 {
234+ if ext .name != "permessage-deflate" || len (exts ) > 1 || copts == nil {
232235 return nil , fmt .Errorf ("WebSocket protcol violation: unsupported extensions from server: %+v" , exts [1 :])
233236 }
234237
235- copts : = & compressionOptions {}
236- copts . clientMaxWindowBits = 8
238+ copts = & * copts
239+
237240 for _ , p := range ext .params {
238241 switch p {
239242 case "client_no_context_takeover" :
@@ -244,24 +247,6 @@ func verifyServerExtensions(h http.Header) (*compressionOptions, error) {
244247 continue
245248 }
246249
247- if false && strings .HasPrefix (p , "server_max_window_bits" ) {
248- bits , ok := parseExtensionParameter (p , 0 )
249- if ! ok || bits < 8 || bits > 16 {
250- return nil , fmt .Errorf ("invalid server_max_window_bits: %q" , p )
251- }
252- copts .serverMaxWindowBits = bits
253- continue
254- }
255-
256- if false && strings .HasPrefix (p , "client_max_window_bits" ) {
257- bits , ok := parseExtensionParameter (p , 0 )
258- if ! ok || bits < 8 || bits > 16 {
259- return nil , fmt .Errorf ("invalid client_max_window_bits: %q" , p )
260- }
261- copts .clientMaxWindowBits = 8
262- continue
263- }
264-
265250 return nil , fmt .Errorf ("unsupported permessage-deflate parameter: %q" , p )
266251 }
267252
0 commit comments