@@ -5,6 +5,8 @@ package websocket_test
55import (
66 "context"
77 "fmt"
8+ "io"
9+ "io/ioutil"
810 "net/http"
911 "net/http/httptest"
1012 "os"
@@ -200,6 +202,121 @@ func TestConn(t *testing.T) {
200202 t .Fatalf ("unexpected error: %v" , err )
201203 }
202204 })
205+
206+ t .Run ("concurrentWriteError" , func (t * testing.T ) {
207+ t .Parallel ()
208+
209+ c1 , c2 , err := wstest .Pipe (nil , nil )
210+ if err != nil {
211+ t .Fatal (err )
212+ }
213+ defer c2 .Close (websocket .StatusInternalError , "" )
214+ defer c1 .Close (websocket .StatusInternalError , "" )
215+
216+ _ , err = c1 .Writer (context .Background (), websocket .MessageText )
217+ if err != nil {
218+ t .Fatal (err )
219+ }
220+
221+ ctx , cancel := context .WithTimeout (context .Background (), time .Millisecond * 100 )
222+ defer cancel ()
223+
224+ err = c1 .Write (ctx , websocket .MessageText , []byte ("x" ))
225+ if ! xerrors .Is (err , context .DeadlineExceeded ) {
226+ t .Fatal (err )
227+ }
228+ })
229+
230+ t .Run ("netConn" , func (t * testing.T ) {
231+ t .Parallel ()
232+
233+ c1 , c2 , err := wstest .Pipe (nil , nil )
234+ if err != nil {
235+ t .Fatal (err )
236+ }
237+ defer c2 .Close (websocket .StatusInternalError , "" )
238+ defer c1 .Close (websocket .StatusInternalError , "" )
239+
240+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
241+ defer cancel ()
242+
243+ n1 := websocket .NetConn (ctx , c1 , websocket .MessageBinary )
244+ n2 := websocket .NetConn (ctx , c2 , websocket .MessageBinary )
245+
246+ // Does not give any confidence but at least ensures no crashes.
247+ d , _ := ctx .Deadline ()
248+ n1 .SetDeadline (d )
249+ n1 .SetDeadline (time.Time {})
250+
251+ if n1 .RemoteAddr () != n1 .LocalAddr () {
252+ t .Fatal ()
253+ }
254+ if n1 .RemoteAddr ().String () != "websocket/unknown-addr" || n1 .RemoteAddr ().Network () != "websocket" {
255+ t .Fatal (n1 .RemoteAddr ())
256+ }
257+
258+ errs := xsync .Go (func () error {
259+ _ , err := n2 .Write ([]byte ("hello" ))
260+ if err != nil {
261+ return err
262+ }
263+ return n2 .Close ()
264+ })
265+
266+ b , err := ioutil .ReadAll (n1 )
267+ if err != nil {
268+ t .Fatal (err )
269+ }
270+
271+ _ , err = n1 .Read (nil )
272+ if err != io .EOF {
273+ t .Fatalf ("expected EOF: %v" , err )
274+ }
275+
276+ err = <- errs
277+ if err != nil {
278+ t .Fatal (err )
279+ }
280+
281+ if ! cmp .Equal ([]byte ("hello" ), b ) {
282+ t .Fatalf ("unexpected msg: %v" , cmp .Diff ([]byte ("hello" ), b ))
283+ }
284+ })
285+
286+ t .Run ("netConn" , func (t * testing.T ) {
287+ t .Parallel ()
288+
289+ c1 , c2 , err := wstest .Pipe (nil , nil )
290+ if err != nil {
291+ t .Fatal (err )
292+ }
293+ defer c2 .Close (websocket .StatusInternalError , "" )
294+ defer c1 .Close (websocket .StatusInternalError , "" )
295+
296+ ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
297+ defer cancel ()
298+
299+ n1 := websocket .NetConn (ctx , c1 , websocket .MessageBinary )
300+ n2 := websocket .NetConn (ctx , c2 , websocket .MessageText )
301+
302+ errs := xsync .Go (func () error {
303+ _ , err := n2 .Write ([]byte ("hello" ))
304+ if err != nil {
305+ return err
306+ }
307+ return nil
308+ })
309+
310+ _ , err = ioutil .ReadAll (n1 )
311+ if ! cmp .ErrorContains (err , `unexpected frame type read (expected MessageBinary): MessageText` ) {
312+ t .Fatal (err )
313+ }
314+
315+ err = <- errs
316+ if err != nil {
317+ t .Fatal (err )
318+ }
319+ })
203320}
204321
205322func TestWasm (t * testing.T ) {
0 commit comments