From 8b373b8dca8d2e550cc4e4ced25c7d1bf2805945 Mon Sep 17 00:00:00 2001 From: xhe Date: Wed, 2 Nov 2022 16:43:02 +0800 Subject: [PATCH] *: handle more nil error cases (#129) --- lib/util/errors/error.go | 3 +++ lib/util/errors/error_test.go | 2 ++ lib/util/errors/werror.go | 3 +++ lib/util/errors/werror_test.go | 3 ++- pkg/proxy/net/packetio.go | 12 ++---------- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/util/errors/error.go b/lib/util/errors/error.go index 076775ab..55351ae9 100644 --- a/lib/util/errors/error.go +++ b/lib/util/errors/error.go @@ -35,6 +35,9 @@ type Error struct { // WithStack will wrapping an error with stacktrace, given a default stack depth. func WithStack(err error) error { + if err == nil { + return nil + } e := &Error{err: err} e.withStackDepth(1, defaultStackDepth) return e diff --git a/lib/util/errors/error_test.go b/lib/util/errors/error_test.go index 909dd0ae..e619f38d 100644 --- a/lib/util/errors/error_test.go +++ b/lib/util/errors/error_test.go @@ -29,6 +29,8 @@ func TestStacktrace(t *testing.T) { require.Contains(t, fmt.Sprintf("%+v", e), t.Name(), "stacktrace must contain test name") require.Contains(t, fmt.Sprintf("%v", e), t.Name(), "stacktrace must contain test name") require.Contains(t, fmt.Sprintf("%+s", e), t.Name(), "stacktrace must contain test name") + + require.Nil(t, serr.WithStack(nil), "wrap nil got nil") } func TestUnwrap(t *testing.T) { diff --git a/lib/util/errors/werror.go b/lib/util/errors/werror.go index 8b7c77b6..8f96d4b7 100644 --- a/lib/util/errors/werror.go +++ b/lib/util/errors/werror.go @@ -69,6 +69,9 @@ func (e *WError) Unwrap() error { // Note that wrap nil error will get nil error. func Wrap(cerr error, uerr error) error { if cerr == nil { + return uerr + } + if uerr == nil { return nil } return &WError{ diff --git a/lib/util/errors/werror_test.go b/lib/util/errors/werror_test.go index f24a8573..860a060a 100644 --- a/lib/util/errors/werror_test.go +++ b/lib/util/errors/werror_test.go @@ -28,7 +28,8 @@ func TestWrap(t *testing.T) { require.ErrorIsf(t, e, e1, "equal to the external error") require.ErrorAsf(t, e, &e2, "unwrapping to the internal error") - require.Nil(t, serr.Wrap(nil, e2), "wrap nil got nil") + require.Equal(t, e2, serr.Wrap(nil, e2), "wrap with nil got the original") + require.Nil(t, serr.Wrap(e2, nil), "wrap nil got nil") } func TestWrapf(t *testing.T) { diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 5795ae8a..d9a37cc2 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -103,12 +103,7 @@ func NewPacketIO(conn net.Conn, opts ...PacketIOption) *PacketIO { } func (p *PacketIO) wrapErr(err error) error { - e := err - if p.wrap != nil { - e = errors.Wrap(p.wrap, err) - } - e = errors.WithStack(e) - return e + return errors.WithStack(errors.Wrap(p.wrap, err)) } // Proxy returned parsed proxy header from clients if any. @@ -248,9 +243,6 @@ func (p *PacketIO) Flush() error { func (p *PacketIO) Close() error { var errs []error - if p.wrap != nil { - errs = append(errs, p.wrap) - } /* TODO: flush when we want to smoothly exit if err := p.Flush(); err != nil { @@ -262,5 +254,5 @@ func (p *PacketIO) Close() error { errs = append(errs, err) } } - return errors.Collect(ErrCloseConn, errs...) + return p.wrapErr(errors.Collect(ErrCloseConn, errs...)) }