@@ -78,11 +78,10 @@ type Conn struct {
7878 readLock chan struct {}
7979
8080 // messageReader state.
81- readerMsgCtx context.Context
82- readerMsgHeader header
83- readerFrameEOF bool
84- readerMaskPos int
85- readerShouldLock bool
81+ readerMsgCtx context.Context
82+ readerMsgHeader header
83+ readerFrameEOF bool
84+ readerMaskPos int
8685
8786 setReadTimeout chan context.Context
8887 setWriteTimeout chan context.Context
@@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
445444 c .readerFrameEOF = false
446445 c .readerMaskPos = 0
447446 c .readMsgLeft = c .msgReadLimit .Load ()
448- c .readerShouldLock = lock
449447
450448 r := & messageReader {
451449 c : c ,
@@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {
465463
466464// Read reads as many bytes as possible into p.
467465func (r * messageReader ) Read (p []byte ) (int , error ) {
468- n , err := r .read (p )
466+ return r .exportedRead (p , true )
467+ }
468+
469+ func (r * messageReader ) exportedRead (p []byte , lock bool ) (int , error ) {
470+ n , err := r .read (p , lock )
469471 if err != nil {
470472 // Have to return io.EOF directly for now, we cannot wrap as errors.Is
471473 // isn't used widely yet.
@@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
477479 return n , nil
478480}
479481
480- func (r * messageReader ) read (p []byte ) (int , error ) {
481- if r .c .readerShouldLock {
482- err := r .c .acquireLock (r .c .readerMsgCtx , r .c .readLock )
483- if err != nil {
484- return 0 , err
482+ func (r * messageReader ) readUnlocked (p []byte ) (int , error ) {
483+ return r .exportedRead (p , false )
484+ }
485+
486+ func (r * messageReader ) read (p []byte , lock bool ) (int , error ) {
487+ if lock {
488+ // If we cannot acquire the read lock, then
489+ // there is either a concurrent read or the close handshake
490+ // is proceeding.
491+ select {
492+ case r .c .readLock <- struct {}{}:
493+ defer r .c .releaseLock (r .c .readLock )
494+ default :
495+ if r .c .closing .Load () == 1 {
496+ <- r .c .closed
497+ return 0 , r .c .closeErr
498+ }
499+ return 0 , errors .New ("concurrent read detected" )
485500 }
486- defer r .c .releaseLock (r .c .readLock )
487501 }
488502
489503 if r .eof () {
490- return 0 , fmt . Errorf ("cannot use EOFed reader" )
504+ return 0 , errors . New ("cannot use EOFed reader" )
491505 }
492506
493507 if r .c .readMsgLeft <= 0 {
@@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
950964 return c .closeReceived
951965 }
952966
953- c .readerShouldLock = false
954-
955967 b := bpool .Get ()
956968 buf := b .Bytes ()
957969 buf = buf [:cap (buf )]
@@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
965977 }
966978 }
967979
968- _ , err = io .CopyBuffer (ioutil .Discard , c .activeReader , buf )
980+ r := readerFunc (c .activeReader .readUnlocked )
981+ _ , err = io .CopyBuffer (ioutil .Discard , r , buf )
969982 if err != nil {
970983 return err
971984 }
@@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
10191032 }
10201033}
10211034
1035+ type readerFunc func (p []byte ) (int , error )
1036+
1037+ func (f readerFunc ) Read (p []byte ) (int , error ) {
1038+ return f (p )
1039+ }
1040+
10221041type writerFunc func (p []byte ) (int , error )
10231042
10241043func (f writerFunc ) Write (p []byte ) (int , error ) {
0 commit comments