Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/gorilla/websocket
Browse files Browse the repository at this point in the history
  • Loading branch information
savsgio committed Dec 4, 2024
2 parents bc9f200 + 5e00238 commit b873e76
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 64 deletions.
18 changes: 15 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,15 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
})
}

// Close the network connection when returning an error. The variable
// netConn is set to nil before the success return at the end of the
// function.
defer func() {
if netConn != nil {
netConn.Close()
// It's safe to ignore the error from Close() because this code is
// only executed when returning a more important error to the
// application.
_ = netConn.Close()
}
}()

Expand Down Expand Up @@ -398,8 +404,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
resp.Body = io.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")

netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
if err := netConn.SetDeadline(time.Time{}); err != nil {
return nil, resp, err
}

// Success! Set netConn to nil to stop the deferred function above from
// closing the network connection.
netConn = nil

return conn, resp, nil
}

Expand Down
10 changes: 5 additions & 5 deletions client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ func TestNoUpgrade(t *testing.T) {
}
resp.Body.Close()
if u := resp.Header.Get("Upgrade"); u != "websocket" {
t.Errorf("Uprade response header is %q, want %q", u, "websocket")
t.Errorf("Upgrade response header is %q, want %q", u, "websocket")
}
if resp.StatusCode != http.StatusUpgradeRequired {
t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusUpgradeRequired)
Expand Down Expand Up @@ -578,7 +578,7 @@ func TestRespOnBadHandshake(t *testing.T) {

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(expectedStatus)
io.WriteString(w, expectedBody)
_, _ = io.WriteString(w, expectedBody)
}))
defer s.Close()

Expand Down Expand Up @@ -828,7 +828,7 @@ func TestSocksProxyDial(t *testing.T) {
}
defer c1.Close()

c1.SetDeadline(time.Now().Add(30 * time.Second))
_ = c1.SetDeadline(time.Now().Add(30 * time.Second))

buf := make([]byte, 32)
if _, err := io.ReadFull(c1, buf[:3]); err != nil {
Expand Down Expand Up @@ -867,10 +867,10 @@ func TestSocksProxyDial(t *testing.T) {
defer c2.Close()
done := make(chan struct{})
go func() {
io.Copy(c1, c2)
_, _ = io.Copy(c1, c2)
close(done)
}()
io.Copy(c2, c1)
_, _ = io.Copy(c2, c1)
<-done
}()

Expand Down
6 changes: 5 additions & 1 deletion compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
"\x01\x00\x00\xff\xff"

fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
mr := io.MultiReader(r, strings.NewReader(tail))
if err := fr.(flate.Resetter).Reset(mr, nil); err != nil {
// Reset never fails, but handle error in case that changes.
fr = flate.NewReader(mr)
}
return &flateReadWrapper{fr}
}

Expand Down
6 changes: 3 additions & 3 deletions compression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestTruncWriter(t *testing.T) {
if m > n {
m = n
}
w.Write(p[:m])
_, _ = w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
Expand All @@ -46,7 +46,7 @@ func BenchmarkWriteNoCompression(b *testing.B) {
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}
Expand All @@ -59,7 +59,7 @@ func BenchmarkWriteWithCompression(b *testing.B) {
c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
_ = c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}
Expand Down
64 changes: 39 additions & 25 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,9 @@ func (c *Conn) read(n int) ([]byte, error) {
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
// Discard is guaranteed to succeed because the number of bytes to discard
// is less than or equal to the number of bytes buffered.
_, _ = c.br.Discard(len(p))
return p, err
}

Expand All @@ -412,7 +414,9 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return ErrNilNetConn
}

c.conn.SetWriteDeadline(deadline)
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if len(buf1) == 0 {
_, err = c.conn.Write(buf0)
} else {
Expand All @@ -422,7 +426,7 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return c.writeFatal(err)
}
if frameType == CloseMessage {
c.writeFatal(ErrCloseSent)
_ = c.writeFatal(ErrCloseSent)
}
return nil
}
Expand Down Expand Up @@ -464,21 +468,27 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
maskBytes(key, 0, buf[6:])
}

d := 1000 * time.Hour
if !deadline.IsZero() {
d = deadline.Sub(time.Now())
if deadline.IsZero() {
// No timeout for zero time.
<-c.mu
} else {
d := time.Until(deadline)
if d < 0 {
return errWriteTimeout
}
select {
case <-c.mu:
default:
timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
}
}

timer := time.NewTimer(d)
select {
case <-c.mu:
timer.Stop()
case <-timer.C:
return errWriteTimeout
}
defer func() { c.mu <- struct{}{} }()

c.writeErrMu.Lock()
Expand All @@ -491,13 +501,14 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er
return ErrNilNetConn
}

c.conn.SetWriteDeadline(deadline)
_, err = c.conn.Write(buf)
if err != nil {
if err := c.conn.SetWriteDeadline(deadline); err != nil {
return c.writeFatal(err)
}
if _, err = c.conn.Write(buf); err != nil {
return c.writeFatal(err)
}
if messageType == CloseMessage {
c.writeFatal(ErrCloseSent)
_ = c.writeFatal(ErrCloseSent)
}
return err
}
Expand Down Expand Up @@ -670,7 +681,7 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
}

if final {
w.endMessage(errWriteClosed)
_ = w.endMessage(errWriteClosed)
return nil
}

Expand Down Expand Up @@ -865,7 +876,7 @@ func (c *Conn) advanceFrame() (int, error) {
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
_ = c.setReadRemaining(int64(p[1] & 0x7f)) // will not fail because argument is >= 0

c.readDecompress = false
if rsv1 {
Expand Down Expand Up @@ -970,7 +981,8 @@ func (c *Conn) advanceFrame() (int, error) {
}

if c.readLimit > 0 && c.readLength > c.readLimit {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
// Make a best effort to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
return noFrame, ErrReadLimit
}

Expand All @@ -982,7 +994,7 @@ func (c *Conn) advanceFrame() (int, error) {
var payload []byte
if c.readRemaining > 0 {
payload, err = c.read(int(c.readRemaining))
c.setReadRemaining(0)
_ = c.setReadRemaining(0) // will not fail because argument is >= 0
if err != nil {
return noFrame, err
}
Expand Down Expand Up @@ -1032,7 +1044,8 @@ func (c *Conn) handleProtocolError(message string) error {
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
// Make a best effor to send a close message describing the problem.
_ = c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

Expand Down Expand Up @@ -1111,7 +1124,7 @@ func (r *messageReader) Read(b []byte) (int, error) {
}
rem := c.readRemaining
rem -= int64(n)
c.setReadRemaining(rem)
_ = c.setReadRemaining(rem) // rem is guaranteed to be >= 0
if c.readRemaining > 0 && c.readErr == io.EOF {
c.readErr = errUnexpectedEOF
}
Expand Down Expand Up @@ -1211,7 +1224,8 @@ func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil {
h = func(code int, text string) error {
message := FormatCloseMessage(code, "")
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
// Make a best effor to send the close message.
_ = c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil
}
}
Expand Down Expand Up @@ -1239,7 +1253,7 @@ func (c *Conn) SetPingHandler(h func(appData string) error) {
}
if h == nil {
h = func(message string) error {
// Make a best effort to send the pong mesage.
// Make a best effort to send the pong message.
_ = c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait))
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions conn_broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ func (b *broadcastBench) makeConns(numConns int) {
select {
case msg := <-c.msgCh:
if msg.prepared != nil {
c.conn.WritePreparedMessage(msg.prepared)
_ = c.conn.WritePreparedMessage(msg.prepared)
} else {
c.conn.WriteMessage(TextMessage, msg.payload)
_ = c.conn.WriteMessage(TextMessage, msg.payload)
}
val := atomic.AddInt32(&b.count, 1)
if val%int32(numConns) == 0 {
Expand Down
Loading

0 comments on commit b873e76

Please sign in to comment.