@@ -98,82 +98,106 @@ func CloseStatus(err error) StatusCode {
9898//
9999// Close will unblock all goroutines interacting with the connection once
100100// complete.
101- func (c * Conn ) Close (code StatusCode , reason string ) error {
102- defer c .wg .Wait ()
103- return c .closeHandshake (code , reason )
101+ func (c * Conn ) Close (code StatusCode , reason string ) (err error ) {
102+ defer errd .Wrap (& err , "failed to close WebSocket" )
103+
104+ if ! c .casClosing () {
105+ err = c .waitGoroutines ()
106+ if err != nil {
107+ return err
108+ }
109+ return net .ErrClosed
110+ }
111+ defer func () {
112+ if errors .Is (err , net .ErrClosed ) {
113+ err = nil
114+ }
115+ }()
116+
117+ err = c .closeHandshake (code , reason )
118+
119+ err2 := c .close ()
120+ if err == nil && err2 != nil {
121+ err = err2
122+ }
123+
124+ err2 = c .waitGoroutines ()
125+ if err == nil && err2 != nil {
126+ err = err2
127+ }
128+
129+ return err
104130}
105131
106132// CloseNow closes the WebSocket connection without attempting a close handshake.
107133// Use when you do not want the overhead of the close handshake.
108134func (c * Conn ) CloseNow () (err error ) {
109- defer c .wg .Wait ()
110135 defer errd .Wrap (& err , "failed to close WebSocket" )
111136
112- if c .isClosed () {
137+ if ! c .casClosing () {
138+ err = c .waitGoroutines ()
139+ if err != nil {
140+ return err
141+ }
113142 return net .ErrClosed
114143 }
144+ defer func () {
145+ if errors .Is (err , net .ErrClosed ) {
146+ err = nil
147+ }
148+ }()
115149
116- c .close (nil )
117- c .closeMu .Lock ()
118- defer c .closeMu .Unlock ()
119- return c .closeErr
120- }
121-
122- func (c * Conn ) closeHandshake (code StatusCode , reason string ) (err error ) {
123- defer errd .Wrap (& err , "failed to close WebSocket" )
124-
125- writeErr := c .writeClose (code , reason )
126- closeHandshakeErr := c .waitCloseHandshake ()
150+ err = c .close ()
127151
128- if writeErr != nil {
129- return writeErr
152+ err2 := c .waitGoroutines ()
153+ if err == nil && err2 != nil {
154+ err = err2
130155 }
156+ return err
157+ }
131158
132- if CloseStatus (closeHandshakeErr ) == - 1 && ! errors .Is (net .ErrClosed , closeHandshakeErr ) {
133- return closeHandshakeErr
159+ func (c * Conn ) closeHandshake (code StatusCode , reason string ) error {
160+ err := c .writeClose (code , reason )
161+ if err != nil {
162+ return err
134163 }
135164
165+ err = c .waitCloseHandshake ()
166+ if CloseStatus (err ) != code {
167+ return err
168+ }
136169 return nil
137170}
138171
139172func (c * Conn ) writeClose (code StatusCode , reason string ) error {
140- c .closeMu .Lock ()
141- wroteClose := c .wroteClose
142- c .wroteClose = true
143- c .closeMu .Unlock ()
144- if wroteClose {
145- return net .ErrClosed
146- }
147-
148173 ce := CloseError {
149174 Code : code ,
150175 Reason : reason ,
151176 }
152177
153178 var p []byte
154- var marshalErr error
179+ var err error
155180 if ce .Code != StatusNoStatusRcvd {
156- p , marshalErr = ce .bytes ()
157- }
158-
159- writeErr := c .writeControl (context .Background (), opClose , p )
160- if CloseStatus (writeErr ) != - 1 {
161- // Not a real error if it's due to a close frame being received.
162- writeErr = nil
181+ p , err = ce .bytes ()
182+ if err != nil {
183+ return err
184+ }
163185 }
164186
165- // We do this after in case there was an error writing the close frame.
166- c . setCloseErr ( fmt . Errorf ( "sent close frame: %w" , ce ) )
187+ ctx , cancel := context . WithTimeout ( context . Background (), time . Second * 5 )
188+ defer cancel ( )
167189
168- if marshalErr != nil {
169- return marshalErr
190+ err = c .writeControl (ctx , opClose , p )
191+ // If the connection closed as we're writing we ignore the error as we might
192+ // have written the close frame, the peer responded and then someone else read it
193+ // and closed the connection.
194+ if err != nil && ! errors .Is (err , net .ErrClosed ) {
195+ return err
170196 }
171- return writeErr
197+ return nil
172198}
173199
174200func (c * Conn ) waitCloseHandshake () error {
175- defer c .close (nil )
176-
177201 ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
178202 defer cancel ()
179203
@@ -209,6 +233,36 @@ func (c *Conn) waitCloseHandshake() error {
209233 }
210234}
211235
236+ func (c * Conn ) waitGoroutines () error {
237+ t := time .NewTimer (time .Second * 15 )
238+ defer t .Stop ()
239+
240+ select {
241+ case <- c .timeoutLoopDone :
242+ case <- t .C :
243+ return errors .New ("failed to wait for timeoutLoop goroutine to exit" )
244+ }
245+
246+ c .closeReadMu .Lock ()
247+ ctx := c .closeReadCtx
248+ c .closeReadMu .Unlock ()
249+ if ctx != nil {
250+ select {
251+ case <- ctx .Done ():
252+ case <- t .C :
253+ return errors .New ("failed to wait for close read goroutine to exit" )
254+ }
255+ }
256+
257+ select {
258+ case <- c .closed :
259+ case <- t .C :
260+ return errors .New ("failed to wait for connection to be closed" )
261+ }
262+
263+ return nil
264+ }
265+
212266func parseClosePayload (p []byte ) (CloseError , error ) {
213267 if len (p ) == 0 {
214268 return CloseError {
@@ -279,16 +333,14 @@ func (ce CloseError) bytesErr() ([]byte, error) {
279333 return buf , nil
280334}
281335
282- func (c * Conn ) setCloseErr ( err error ) {
336+ func (c * Conn ) casClosing () bool {
283337 c .closeMu .Lock ()
284- c .setCloseErrLocked (err )
285- c .closeMu .Unlock ()
286- }
287-
288- func (c * Conn ) setCloseErrLocked (err error ) {
289- if c .closeErr == nil && err != nil {
290- c .closeErr = fmt .Errorf ("WebSocket closed: %w" , err )
338+ defer c .closeMu .Unlock ()
339+ if ! c .closing {
340+ c .closing = true
341+ return true
291342 }
343+ return false
292344}
293345
294346func (c * Conn ) isClosed () bool {
0 commit comments