Skip to content

Commit

Permalink
net, backend: add more error types to disconnection error (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
djshow832 authored Mar 20, 2024
1 parent 233f456 commit 1c32cf8
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pkg/proxy/backend/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func Error2Source(err error) ErrorSource {
return SrcNone
}
// Disconnection errors may come from other errors such as ErrProxyNoBackend and ErrBackendHandshake.
// ErrClientConn and ErrBackendConn may include non-connection errors.
// ErrClientConn and ErrBackendConn may include non-connection errors such as wrong PPV2 format, TLS cert error.
if pnet.IsDisconnectError(err) {
if errors.Is(err, ErrClientConn) {
return SrcClientNetwork
Expand Down
17 changes: 17 additions & 0 deletions pkg/proxy/net/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
package net

import (
"context"
"io"
"os"
"syscall"

"github.com/pingcap/tiproxy/lib/util/errors"
)

Expand All @@ -15,3 +20,15 @@ var (
ErrCloseConn = errors.New("failed to close the connection")
ErrHandshakeTLS = errors.New("failed to complete tls handshake")
)

// IsDisconnectError returns whether the error is caused by peer disconnection.
func IsDisconnectError(err error) bool {
switch {
// Do not use os.Timeout(err) because it doesn't unwrap the error.
case errors.Is(err, io.EOF), errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ECONNRESET),
errors.Is(err, syscall.ECONNABORTED), errors.Is(err, syscall.ETIMEDOUT), errors.Is(err, os.ErrDeadlineExceeded),
errors.Is(err, context.DeadlineExceeded):
return true
}
return false
}
35 changes: 35 additions & 0 deletions pkg/proxy/net/error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2024 PingCAP, Inc.
// SPDX-License-Identifier: Apache-2.0

package net

import (
"context"
"os"
"syscall"
"testing"

"github.com/pingcap/tiproxy/lib/util/errors"
"github.com/stretchr/testify/require"
)

func TestIsDisconnectErr(t *testing.T) {
disConnErrors := []error{
syscall.ETIMEDOUT,
os.ErrDeadlineExceeded,
context.DeadlineExceeded,
errors.Wrap(errors.New("mock"), syscall.ETIMEDOUT),
errors.Wrap(syscall.ETIMEDOUT, errors.New("mock")),
}
for _, err := range disConnErrors {
require.True(t, IsDisconnectError(err))
}

otherErrors := []error{
syscall.ENOENT,
errors.New("mock"),
}
for _, err := range otherErrors {
require.False(t, IsDisconnectError(err))
}
}
23 changes: 0 additions & 23 deletions pkg/proxy/net/net_err.go

This file was deleted.

10 changes: 5 additions & 5 deletions pkg/proxy/net/packetio.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func ReadFull(prw packetReadWriter, b []byte) error {
for n := 0; n < m; {
nn, err := prw.Read(b[n:])
if err != nil {
return errors.WithStack(err)
return err
}
n += nn
}
Expand Down Expand Up @@ -220,7 +220,7 @@ func (p *PacketIO) ApplyOpts(opts ...PacketIOption) {
}

func (p *PacketIO) wrapErr(err error) error {
return errors.WithStack(errors.Wrap(p.wrap, err))
return errors.Wrap(p.wrap, err)
}

func (p *PacketIO) LocalAddr() net.Addr {
Expand Down Expand Up @@ -371,7 +371,7 @@ func (p *PacketIO) ForwardUntil(dest *PacketIO, isEnd func(firstByte byte, first
break
}
if header, err = p.readWriter.Peek(4); err != nil {
return p.wrapErr(errors.Wrap(ErrReadConn, err))
return p.wrapErr(errors.Wrap(ErrReadConn, errors.WithStack(err)))
}
length = int(header[0]) | int(header[1])<<8 | int(header[2])<<16
}
Expand Down Expand Up @@ -405,7 +405,7 @@ func (p *PacketIO) OutPackets() uint64 {

func (p *PacketIO) Flush() error {
if err := p.readWriter.Flush(); err != nil {
return p.wrapErr(errors.Wrap(ErrFlushConn, err))
return p.wrapErr(errors.Wrap(ErrFlushConn, errors.WithStack(err)))
}
return nil
}
Expand Down Expand Up @@ -443,7 +443,7 @@ func (p *PacketIO) Close() error {
}
*/
if err := p.readWriter.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
errs = append(errs, err)
errs = append(errs, errors.WithStack(err))
}
return p.wrapErr(errors.Collect(ErrCloseConn, errs...))
}
4 changes: 2 additions & 2 deletions pkg/proxy/net/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (p *PacketIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionStat
conn := &tlsInternalConn{p.readWriter}
tlsConn := tls.Server(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return tls.ConnectionState{}, p.wrapErr(errors.Wrap(ErrHandshakeTLS, err))
return tls.ConnectionState{}, p.wrapErr(errors.Wrap(ErrHandshakeTLS, errors.WithStack(err)))
}
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn)
return tlsConn.ConnectionState(), nil
Expand All @@ -39,7 +39,7 @@ func (p *PacketIO) ClientTLSHandshake(tlsConfig *tls.Config) error {
conn := &tlsInternalConn{p.readWriter}
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.Handshake(); err != nil {
return p.wrapErr(errors.Wrap(ErrHandshakeTLS, err))
return p.wrapErr(errors.Wrap(ErrHandshakeTLS, errors.WithStack(err)))
}
p.readWriter = newTLSReadWriter(p.readWriter, tlsConn)
return nil
Expand Down

0 comments on commit 1c32cf8

Please sign in to comment.