@@ -6,10 +6,12 @@ package websocket
66import (
77 "bufio"
88 "errors"
9+ "io"
910 "net"
1011 "net/http"
1112 "net/http/httptest"
1213 "strings"
14+ "sync"
1315 "testing"
1416
1517 "nhooyr.io/websocket/internal/test/assert"
@@ -142,6 +144,43 @@ func TestAccept(t *testing.T) {
142144 _ , err := Accept (w , r , nil )
143145 assert .Contains (t , err , `failed to hijack connection` )
144146 })
147+ t .Run ("closeRace" , func (t * testing.T ) {
148+ t .Parallel ()
149+
150+ server , _ := net .Pipe ()
151+
152+ pr , pw := io .Pipe ()
153+ rw := bufio .NewReadWriter (bufio .NewReader (pr ), bufio .NewWriter (pw ))
154+ newResponseWriter := func () http.ResponseWriter {
155+ return mockHijacker {
156+ ResponseWriter : httptest .NewRecorder (),
157+ hijack : func () (net.Conn , * bufio.ReadWriter , error ) {
158+ return server , rw , nil
159+ },
160+ }
161+ }
162+ w := newResponseWriter ()
163+
164+ r := httptest .NewRequest ("GET" , "/" , nil )
165+ r .Header .Set ("Connection" , "Upgrade" )
166+ r .Header .Set ("Upgrade" , "websocket" )
167+ r .Header .Set ("Sec-WebSocket-Version" , "13" )
168+ r .Header .Set ("Sec-WebSocket-Key" , xrand .Base64 (16 ))
169+
170+ c , err := Accept (w , r , nil )
171+ wg := & sync.WaitGroup {}
172+ wg .Add (2 )
173+ go func () {
174+ c .Close (StatusInternalError , "the sky is falling" )
175+ wg .Done ()
176+ }()
177+ go func () {
178+ c .CloseNow ()
179+ wg .Done ()
180+ }()
181+ wg .Wait ()
182+ assert .Success (t , err )
183+ })
145184}
146185
147186func Test_verifyClientHandshake (t * testing.T ) {
0 commit comments