@@ -78,11 +78,10 @@ type Conn struct {
7878 readLock chan struct {}
7979
8080 // messageReader state.
81- readerMsgCtx context.Context
82- readerMsgHeader header
83- readerFrameEOF bool
84- readerMaskPos int
85- readerShouldLock bool
81+ readerMsgCtx context.Context
82+ readerMsgHeader header
83+ readerFrameEOF bool
84+ readerMaskPos int
8685
8786 setReadTimeout chan context.Context
8887 setWriteTimeout chan context.Context
@@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
237236 if h .opcode .controlOp () {
238237 err = c .handleControl (ctx , h )
239238 if err != nil {
239+ // Pass through CloseErrors when receiving a close frame.
240+ if h .opcode == opClose && CloseStatus (err ) != - 1 {
241+ return header {}, err
242+ }
240243 return header {}, fmt .Errorf ("failed to handle control frame %v: %w" , h .opcode , err )
241244 }
242245 continue
@@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445448 c .readerFrameEOF = false
446449 c .readerMaskPos = 0
447450 c .readMsgLeft = c .msgReadLimit .Load ()
448- c .readerShouldLock = lock
449451
450452 r := & messageReader {
451453 c : c ,
@@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {
465467
466468// Read reads as many bytes as possible into p.
467469func (r * messageReader ) Read (p []byte ) (int , error ) {
468- n , err := r .read (p )
470+ return r .exportedRead (p , true )
471+ }
472+
473+ func (r * messageReader ) exportedRead (p []byte , lock bool ) (int , error ) {
474+ n , err := r .read (p , lock )
469475 if err != nil {
470476 // Have to return io.EOF directly for now, we cannot wrap as errors.Is
471477 // isn't used widely yet.
@@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477483 return n , nil
478484}
479485
480- func (r * messageReader ) read (p []byte ) (int , error ) {
481- if r .c .readerShouldLock {
482- err := r .c .acquireLock (r .c .readerMsgCtx , r .c .readLock )
483- if err != nil {
484- return 0 , err
486+ func (r * messageReader ) readUnlocked (p []byte ) (int , error ) {
487+ return r .exportedRead (p , false )
488+ }
489+
490+ func (r * messageReader ) read (p []byte , lock bool ) (int , error ) {
491+ if lock {
492+ // If we cannot acquire the read lock, then
493+ // there is either a concurrent read or the close handshake
494+ // is proceeding.
495+ select {
496+ case r .c .readLock <- struct {}{}:
497+ defer r .c .releaseLock (r .c .readLock )
498+ default :
499+ if r .c .closing .Load () == 1 {
500+ <- r .c .closed
501+ return 0 , r .c .closeErr
502+ }
503+ return 0 , errors .New ("concurrent read detected" )
485504 }
486- defer r .c .releaseLock (r .c .readLock )
487505 }
488506
489507 if r .eof () {
490- return 0 , fmt . Errorf ("cannot use EOFed reader" )
508+ return 0 , errors . New ("cannot use EOFed reader" )
491509 }
492510
493511 if r .c .readMsgLeft <= 0 {
@@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
950968 return c .closeReceived
951969 }
952970
953- c .readerShouldLock = false
954-
955971 b := bpool .Get ()
956972 buf := b .Bytes ()
957973 buf = buf [:cap (buf )]
@@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
965981 }
966982 }
967983
968- _ , err = io .CopyBuffer (ioutil .Discard , c .activeReader , buf )
984+ r := readerFunc (c .activeReader .readUnlocked )
985+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
969986 if err != nil {
970987 return err
971988 }
@@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
10191036 }
10201037}
10211038
1039+ type readerFunc func (p []byte ) (int , error )
1040+
1041+ func (f readerFunc ) Read (p []byte ) (int , error ) {
1042+ return f (p )
1043+ }
1044+
10221045type writerFunc func (p []byte ) (int , error )
10231046
10241047func (f writerFunc ) Write (p []byte ) (int , error ) {
0 commit comments