11//go:build !js
22// +build !js
33
4- package websocket
4+ package websocket_test
55
66import (
77 "bytes"
@@ -10,12 +10,15 @@ import (
1010 "io"
1111 "net/http"
1212 "net/http/httptest"
13+ "net/url"
1314 "strings"
1415 "testing"
1516 "time"
1617
18+ "nhooyr.io/websocket"
1719 "nhooyr.io/websocket/internal/test/assert"
1820 "nhooyr.io/websocket/internal/util"
21+ "nhooyr.io/websocket/internal/xsync"
1922)
2023
2124func TestBadDials (t * testing.T ) {
@@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) {
2730 testCases := []struct {
2831 name string
2932 url string
30- opts * DialOptions
33+ opts * websocket. DialOptions
3134 rand util.ReaderFunc
3235 nilCtx bool
3336 }{
@@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) {
7275 tc .rand = rand .Reader .Read
7376 }
7477
75- _ , _ , err := dial (ctx , tc .url , tc .opts , tc .rand )
78+ _ , _ , err := websocket . ExportedDial (ctx , tc .url , tc .opts , tc .rand )
7679 assert .Error (t , err )
7780 })
7881 }
@@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) {
8487 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
8588 defer cancel ()
8689
87- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
90+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
8891 HTTPClient : mockHTTPClient (func (* http.Request ) (* http.Response , error ) {
8992 return & http.Response {
9093 Body : io .NopCloser (strings .NewReader ("hi" )),
@@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) {
104107 h := http.Header {}
105108 h .Set ("Connection" , "Upgrade" )
106109 h .Set ("Upgrade" , "websocket" )
107- h .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
110+ h .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
108111
109112 return & http.Response {
110113 StatusCode : http .StatusSwitchingProtocols ,
@@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) {
113116 }, nil
114117 }
115118
116- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
119+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
117120 HTTPClient : mockHTTPClient (rt ),
118121 })
119122 assert .Contains (t , err , "response body is not a io.ReadWriteCloser" )
@@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) {
152155 h := http.Header {}
153156 h .Set ("Connection" , "Upgrade" )
154157 h .Set ("Upgrade" , "websocket" )
155- h .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
158+ h .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (r .Header .Get ("Sec-WebSocket-Key" )))
156159
157160 return & http.Response {
158161 StatusCode : http .StatusSwitchingProtocols ,
@@ -161,7 +164,7 @@ func Test_verifyHostOverride(t *testing.T) {
161164 }, nil
162165 }
163166
164- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
167+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
165168 HTTPClient : mockHTTPClient (rt ),
166169 Host : tc .host ,
167170 })
@@ -272,18 +275,18 @@ func Test_verifyServerHandshake(t *testing.T) {
272275 resp := w .Result ()
273276
274277 r := httptest .NewRequest ("GET" , "/" , nil )
275- key , err := secWebSocketKey (rand .Reader )
278+ key , err := websocket . SecWebSocketKey (rand .Reader )
276279 assert .Success (t , err )
277280 r .Header .Set ("Sec-WebSocket-Key" , key )
278281
279282 if resp .Header .Get ("Sec-WebSocket-Accept" ) == "" {
280- resp .Header .Set ("Sec-WebSocket-Accept" , secWebSocketAccept (key ))
283+ resp .Header .Set ("Sec-WebSocket-Accept" , websocket . SecWebSocketAccept (key ))
281284 }
282285
283- opts := & DialOptions {
286+ opts := & websocket. DialOptions {
284287 Subprotocols : strings .Split (r .Header .Get ("Sec-WebSocket-Protocol" ), "," ),
285288 }
286- _ , err = verifyServerResponse (opts , opts .CompressionMode . opts ( ), key , resp )
289+ _ , err = websocket . VerifyServerResponse (opts , websocket . CompressionModeOpts ( opts .CompressionMode ), key , resp )
287290 if tc .success {
288291 assert .Success (t , err )
289292 } else {
@@ -311,7 +314,7 @@ func TestDialRedirect(t *testing.T) {
311314 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
312315 defer cancel ()
313316
314- _ , _ , err := Dial (ctx , "ws://example.com" , & DialOptions {
317+ _ , _ , err := websocket . Dial (ctx , "ws://example.com" , & websocket. DialOptions {
315318 HTTPClient : mockHTTPClient (func (r * http.Request ) (* http.Response , error ) {
316319 resp := & http.Response {
317320 Header : http.Header {},
@@ -321,11 +324,88 @@ func TestDialRedirect(t *testing.T) {
321324 resp .StatusCode = http .StatusFound
322325 return resp , nil
323326 }
324- resp .Header .Set ("Connection" , "Upgrade" )
325- resp .Header .Set ("Upgrade" , "meow" )
327+ resp .Header .Set ("Connection" , "Upgrade" )
328+ resp .Header .Set ("Upgrade" , "meow" )
326329 resp .StatusCode = http .StatusSwitchingProtocols
327330 return resp , nil
328331 }),
329332 })
330333 assert .Contains (t , err , "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \" meow\" does not contain websocket" )
331334}
335+
336+ type forwardProxy struct {
337+ hc * http.Client
338+ }
339+
340+ func newForwardProxy () * forwardProxy {
341+ return & forwardProxy {
342+ hc : & http.Client {},
343+ }
344+ }
345+
346+ func (fc * forwardProxy ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
347+ ctx , cancel := context .WithTimeout (r .Context (), time .Second * 10 )
348+ defer cancel ()
349+
350+ r = r .WithContext (ctx )
351+ r .RequestURI = ""
352+ resp , err := fc .hc .Do (r )
353+ if err != nil {
354+ http .Error (w , err .Error (), http .StatusBadRequest )
355+ return
356+ }
357+ defer resp .Body .Close ()
358+
359+ for k , v := range resp .Header {
360+ w .Header ()[k ] = v
361+ }
362+ w .Header ().Set ("PROXIED" , "true" )
363+ w .WriteHeader (resp .StatusCode )
364+ errc1 := xsync .Go (func () error {
365+ _ , err := io .Copy (w , resp .Body )
366+ return err
367+ })
368+ var errc2 <- chan error
369+ if bodyw , ok := resp .Body .(io.Writer ); ok {
370+ errc2 = xsync .Go (func () error {
371+ _ , err := io .Copy (bodyw , r .Body )
372+ return err
373+ })
374+ }
375+ select {
376+ case <- errc1 :
377+ case <- errc2 :
378+ case <- r .Context ().Done ():
379+ }
380+ }
381+
382+ func TestDialViaProxy (t * testing.T ) {
383+ t .Parallel ()
384+
385+ ps := httptest .NewServer (newForwardProxy ())
386+ defer ps .Close ()
387+
388+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
389+ err := echoServer (w , r , nil )
390+ assert .Success (t , err )
391+ }))
392+ defer s .Close ()
393+
394+ psu , err := url .Parse (ps .URL )
395+ assert .Success (t , err )
396+ proxyTransport := http .DefaultTransport .(* http.Transport ).Clone ()
397+ proxyTransport .Proxy = http .ProxyURL (psu )
398+
399+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 10 )
400+ defer cancel ()
401+ c , resp , err := websocket .Dial (ctx , s .URL , & websocket.DialOptions {
402+ HTTPClient : & http.Client {
403+ Transport : proxyTransport ,
404+ },
405+ })
406+ assert .Success (t , err )
407+ assert .Equal (t , "" , "true" , resp .Header .Get ("PROXIED" ))
408+
409+ assertEcho (t , ctx , c )
410+ assertClose (t , c )
411+ }
0 commit comments