@@ -23,24 +23,26 @@ type Conn struct {
2323
2424 msgReadLimit int64
2525
26- readClosed int64
27- closeOnce sync.Once
28- closed chan struct {}
29- closeErr error
26+ readClosed int64
27+ closeOnce sync.Once
28+ closed chan struct {}
29+ closeErrOnce sync.Once
30+ closeErr error
3031
3132 releaseOnClose func ()
3233 releaseOnMessage func ()
3334
3435 readSignal chan struct {}
3536 readBufMu sync.Mutex
36- readBuf []wsjs.MessageEvent
37+ // Max size of readBuf is 32.
38+ readBuf []wsjs.MessageEvent
3739}
3840
3941func (c * Conn ) close (err error ) {
4042 c .closeOnce .Do (func () {
4143 runtime .SetFinalizer (c , nil )
4244
43- c .closeErr = fmt . Errorf ( "websocket closed: %w" , err )
45+ c .setCloseErr ( err )
4446 close (c .closed )
4547 })
4648}
@@ -49,6 +51,8 @@ func (c *Conn) init() {
4951 c .closed = make (chan struct {})
5052 c .readSignal = make (chan struct {}, 1 )
5153 c .msgReadLimit = 32768
54+ // Capacity limited to 32 messages.
55+ c .readBuf = make ([]wsjs.MessageEvent , 0 , 32 )
5256
5357 c .releaseOnClose = c .ws .OnClose (func (e wsjs.CloseEvent ) {
5458 cerr := CloseError {
@@ -66,6 +70,12 @@ func (c *Conn) init() {
6670 c .readBufMu .Lock ()
6771 defer c .readBufMu .Unlock ()
6872
73+ if len (c .readBuf ) == cap (c .readBuf ) {
74+ c .setCloseErr (fmt .Errorf ("too many messages in buffer, cannot keep up: %v" , len (c .readBuf )))
75+ c .Close (StatusPolicyViolation , "unable to read fast enough" )
76+ return
77+ }
78+
6979 c .readBuf = append (c .readBuf , e )
7080
7181 // Lets the read goroutine know there is definitely something in readBuf.
@@ -76,11 +86,15 @@ func (c *Conn) init() {
7686 })
7787
7888 runtime .SetFinalizer (c , func (c * Conn ) {
79- c .ws . Close ( int ( StatusInternalError ), "" )
80- c .close ( errors . New ( "connection garbage collected" ) )
89+ c .setCloseErr ( errors . New ( "connection garbage collected" ) )
90+ c .closeWithInternal ( )
8191 })
8292}
8393
94+ func (c * Conn ) closeWithInternal () {
95+ c .Close (StatusInternalError , "something went wrong" )
96+ }
97+
8498// Read attempts to read a message from the connection.
8599// The maximum time spent waiting is bounded by the context.
86100func (c * Conn ) Read (ctx context.Context ) (MessageType , []byte , error ) {
@@ -113,11 +127,8 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
113127 defer c .readBufMu .Unlock ()
114128
115129 me := c .readBuf [0 ]
116- // Ensures GC can collect the message event.
117- c .readBuf [0 ] = wsjs.MessageEvent {}
118- // We do not shrink the array since it will be resized
119- // as appropriate by append in the OnMessage callback.
120- c .readBuf = c .readBuf [1 :]
130+ copy (c .readBuf , c .readBuf [1 :])
131+ c .readBuf = c .readBuf [:len (c .readBuf )- 1 ]
121132
122133 if len (c .readBuf ) > 0 {
123134 // Next time we read, we'll grab the message.
@@ -146,8 +157,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
146157 // to match the Go API. It can only error if the message type
147158 // is unexpected or the passed bytes contain invalid UTF-8 for
148159 // MessageText.
149- c .Close (StatusInternalError , "something went wrong" )
150- return fmt .Errorf ("failed to write: %w" , err )
160+ err := fmt .Errorf ("failed to write: %w" , err )
161+ c .setCloseErr (err )
162+ c .closeWithInternal ()
163+ return err
151164 }
152165 return nil
153166}
0 commit comments