@@ -65,9 +65,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
6565 opts .CompressionOptions = & CompressionOptions {}
6666 }
6767
68- err = verifyClientRequest (r )
68+ errCode , err : = verifyClientRequest (w , r )
6969 if err != nil {
70- http .Error (w , err .Error (), http . StatusBadRequest )
70+ http .Error (w , err .Error (), errCode )
7171 return nil , err
7272 }
7373
@@ -127,32 +127,37 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
127127 }), nil
128128}
129129
130- func verifyClientRequest (r * http.Request ) error {
130+ func verifyClientRequest (w http. ResponseWriter , r * http.Request ) ( errCode int , _ error ) {
131131 if ! r .ProtoAtLeast (1 , 1 ) {
132- return xerrors .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
132+ return http . StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q" , r .Proto )
133133 }
134134
135135 if ! headerContainsToken (r .Header , "Connection" , "Upgrade" ) {
136- return xerrors .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
136+ w .Header ().Set ("Connection" , "Upgrade" )
137+ w .Header ().Set ("Upgrade" , "websocket" )
138+ return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Connection header %q does not contain Upgrade" , r .Header .Get ("Connection" ))
137139 }
138140
139141 if ! headerContainsToken (r .Header , "Upgrade" , "websocket" ) {
140- return xerrors .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
142+ w .Header ().Set ("Connection" , "Upgrade" )
143+ w .Header ().Set ("Upgrade" , "websocket" )
144+ return http .StatusUpgradeRequired , xerrors .Errorf ("WebSocket protocol violation: Upgrade header %q does not contain websocket" , r .Header .Get ("Upgrade" ))
141145 }
142146
143147 if r .Method != "GET" {
144- return xerrors .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
148+ return http . StatusMethodNotAllowed , xerrors .Errorf ("WebSocket protocol violation: handshake request method is not GET but %q" , r .Method )
145149 }
146150
147151 if r .Header .Get ("Sec-WebSocket-Version" ) != "13" {
148- return xerrors .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
152+ w .Header ().Set ("Sec-WebSocket-Version" , "13" )
153+ return http .StatusBadRequest , xerrors .Errorf ("unsupported WebSocket protocol version (only 13 is supported): %q" , r .Header .Get ("Sec-WebSocket-Version" ))
149154 }
150155
151156 if r .Header .Get ("Sec-WebSocket-Key" ) == "" {
152- return xerrors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
157+ return http . StatusBadRequest , xerrors .New ("WebSocket protocol violation: missing Sec-WebSocket-Key" )
153158 }
154159
155- return nil
160+ return 0 , nil
156161}
157162
158163func authenticateOrigin (r * http.Request ) error {
0 commit comments