-
Notifications
You must be signed in to change notification settings - Fork 896
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GODRIVER-2929 Add the ability to join multiple errors into one. (#1370)
- Loading branch information
1 parent
ec38db6
commit 848b7c2
Showing
4 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (C) MongoDB, Inc. 2023-present. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
package errutil | ||
|
||
import "errors" | ||
|
||
// join is a Go 1.13-1.19 compatible version of [errors.Join]. It is only called | ||
// by Join in join_go1.19.go. It is included here in a file without build | ||
// constraints only for testing purposes. | ||
// | ||
// It is heavily based on Join from | ||
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go | ||
func join(errs ...error) error { | ||
n := 0 | ||
for _, err := range errs { | ||
if err != nil { | ||
n++ | ||
} | ||
} | ||
if n == 0 { | ||
return nil | ||
} | ||
e := &joinError{ | ||
errs: make([]error, 0, n), | ||
} | ||
for _, err := range errs { | ||
if err != nil { | ||
e.errs = append(e.errs, err) | ||
} | ||
} | ||
return e | ||
} | ||
|
||
// joinError is a Go 1.13-1.19 compatible joinable error type. Its error | ||
// message is identical to [errors.Join], but it implements "Unwrap() error" | ||
// instead of "Unwrap() []error". | ||
// | ||
// It is heavily based on the joinError from | ||
// https://cs.opensource.google/go/go/+/refs/tags/go1.21.0:src/errors/join.go | ||
type joinError struct { | ||
errs []error | ||
} | ||
|
||
func (e *joinError) Error() string { | ||
var b []byte | ||
for i, err := range e.errs { | ||
if i > 0 { | ||
b = append(b, '\n') | ||
} | ||
b = append(b, err.Error()...) | ||
} | ||
return string(b) | ||
} | ||
|
||
// Unwrap returns another joinError with the same errors as the current | ||
// joinError except the first error in the slice. Continuing to call Unwrap | ||
// on each returned error will increment through every error in the slice. The | ||
// resulting behavior when using [errors.Is] and [errors.As] is similar to an | ||
// error created using [errors.Join] in Go 1.20+. | ||
func (e *joinError) Unwrap() error { | ||
if len(e.errs) == 1 { | ||
return e.errs[0] | ||
} | ||
return &joinError{errs: e.errs[1:]} | ||
} | ||
|
||
// Is calls [errors.Is] with the first error in the slice. | ||
func (e *joinError) Is(target error) bool { | ||
if len(e.errs) == 0 { | ||
return false | ||
} | ||
return errors.Is(e.errs[0], target) | ||
} | ||
|
||
// As calls [errors.As] with the first error in the slice. | ||
func (e *joinError) As(target interface{}) bool { | ||
if len(e.errs) == 0 { | ||
return false | ||
} | ||
return errors.As(e.errs[0], target) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Copyright (C) MongoDB, Inc. 2023-present. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
//go:build !go1.20 | ||
// +build !go1.20 | ||
|
||
package errutil | ||
|
||
// Join returns an error that wraps the given errors. Any nil error values are | ||
// discarded. Join returns nil if every value in errs is nil. The error formats | ||
// as the concatenation of the strings obtained by calling the Error method of | ||
// each element of errs, with a newline between each string. | ||
// | ||
// A non-nil error returned by Join implements the "Unwrap() error" method. | ||
func Join(errs ...error) error { | ||
return join(errs...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
// Copyright (C) MongoDB, Inc. 2023-present. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
//go:build go1.20 | ||
// +build go1.20 | ||
|
||
package errutil | ||
|
||
import "errors" | ||
|
||
// Join calls [errors.Join]. | ||
func Join(errs ...error) error { | ||
return errors.Join(errs...) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
// Copyright (C) MongoDB, Inc. 2023-present. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); you may | ||
// not use this file except in compliance with the License. You may obtain | ||
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
package errutil | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"go.mongodb.org/mongo-driver/internal/assert" | ||
) | ||
|
||
// TestJoin_Nil asserts that join returns a nil error for the same inputs that | ||
// [errors.Join] returns a nil error. | ||
func TestJoin_Nil(t *testing.T) { | ||
t.Parallel() | ||
|
||
assert.Equal(t, errors.Join(), join(), "errors.Join() != join()") | ||
assert.Equal(t, errors.Join(nil), join(nil), "errors.Join(nil) != join(nil)") | ||
assert.Equal(t, errors.Join(nil, nil), join(nil, nil), "errors.Join(nil, nil) != join(nil, nil)") | ||
} | ||
|
||
// TestJoin_Error asserts that join returns an error with the same error message | ||
// as the error returned by [errors.Join]. | ||
func TestJoin_Error(t *testing.T) { | ||
t.Parallel() | ||
|
||
err1 := errors.New("err1") | ||
err2 := errors.New("err2") | ||
|
||
tests := []struct { | ||
desc string | ||
errs []error | ||
}{{ | ||
desc: "single error", | ||
errs: []error{err1}, | ||
}, { | ||
desc: "two errors", | ||
errs: []error{err1, err2}, | ||
}, { | ||
desc: "two errors and a nil value", | ||
errs: []error{err1, nil, err2}, | ||
}} | ||
|
||
for _, test := range tests { | ||
test := test // Capture range variable. | ||
|
||
t.Run(test.desc, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
want := errors.Join(test.errs...).Error() | ||
got := join(test.errs...).Error() | ||
assert.Equal(t, | ||
want, | ||
got, | ||
"errors.Join().Error() != join().Error() for input %v", | ||
test.errs) | ||
}) | ||
} | ||
} | ||
|
||
// TestJoin_ErrorsIs asserts that join returns an error that behaves identically | ||
// to the error returned by [errors.Join] when passed to [errors.Is]. | ||
func TestJoin_ErrorsIs(t *testing.T) { | ||
t.Parallel() | ||
|
||
err1 := errors.New("err1") | ||
err2 := errors.New("err2") | ||
|
||
tests := []struct { | ||
desc string | ||
errs []error | ||
target error | ||
}{{ | ||
desc: "one error with a matching target", | ||
errs: []error{err1}, | ||
target: err1, | ||
}, { | ||
desc: "one error with a non-matching target", | ||
errs: []error{err1}, | ||
target: err2, | ||
}, { | ||
desc: "nil error", | ||
errs: []error{nil}, | ||
target: err1, | ||
}, { | ||
desc: "no errors", | ||
errs: []error{}, | ||
target: err1, | ||
}, { | ||
desc: "two different errors with a matching target", | ||
errs: []error{err1, err2}, | ||
target: err2, | ||
}, { | ||
desc: "two identical errors with a matching target", | ||
errs: []error{err1, err1}, | ||
target: err1, | ||
}, { | ||
desc: "wrapped error with a matching target", | ||
errs: []error{fmt.Errorf("error: %w", err1)}, | ||
target: err1, | ||
}, { | ||
desc: "nested joined error with a matching target", | ||
errs: []error{err1, join(err2, errors.New("nope"))}, | ||
target: err2, | ||
}, { | ||
desc: "nested joined error with no matching targets", | ||
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))}, | ||
target: err2, | ||
}, { | ||
desc: "nested joined error with a wrapped matching target", | ||
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2}, | ||
target: err1, | ||
}, { | ||
desc: "context.DeadlineExceeded", | ||
errs: []error{err1, nil, context.DeadlineExceeded, err2}, | ||
target: context.DeadlineExceeded, | ||
}, { | ||
desc: "wrapped context.DeadlineExceeded", | ||
errs: []error{err1, nil, fmt.Errorf("error: %w", context.DeadlineExceeded), err2}, | ||
target: context.DeadlineExceeded, | ||
}} | ||
|
||
for _, test := range tests { | ||
test := test // Capture range variable. | ||
|
||
t.Run(test.desc, func(t *testing.T) { | ||
// Assert that top-level errors returned by errors.Join and join | ||
// behave the same with errors.Is. | ||
want := errors.Join(test.errs...) | ||
got := join(test.errs...) | ||
assert.Equal(t, | ||
errors.Is(want, test.target), | ||
errors.Is(got, test.target), | ||
"errors.Join() and join() behave differently with errors.Is") | ||
|
||
// Assert that wrapped errors returned by errors.Join and join | ||
// behave the same with errors.Is. | ||
want = fmt.Errorf("error: %w", errors.Join(test.errs...)) | ||
got = fmt.Errorf("error: %w", join(test.errs...)) | ||
assert.Equal(t, | ||
errors.Is(want, test.target), | ||
errors.Is(got, test.target), | ||
"errors.Join() and join(), when wrapped, behave differently with errors.Is") | ||
}) | ||
} | ||
} | ||
|
||
type errType1 struct{} | ||
|
||
func (errType1) Error() string { return "" } | ||
|
||
type errType2 struct{} | ||
|
||
func (errType2) Error() string { return "" } | ||
|
||
// TestJoin_ErrorsIs asserts that join returns an error that behaves identically | ||
// to the error returned by [errors.Join] when passed to [errors.As]. | ||
func TestJoin_ErrorsAs(t *testing.T) { | ||
t.Parallel() | ||
|
||
err1 := errType1{} | ||
err2 := errType2{} | ||
|
||
tests := []struct { | ||
desc string | ||
errs []error | ||
target interface{} | ||
}{{ | ||
desc: "one error with a matching target", | ||
errs: []error{err1}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "one error with a non-matching target", | ||
errs: []error{err1}, | ||
target: &errType2{}, | ||
}, { | ||
desc: "nil error", | ||
errs: []error{nil}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "no errors", | ||
errs: []error{}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "two different errors with a matching target", | ||
errs: []error{err1, err2}, | ||
target: &errType2{}, | ||
}, { | ||
desc: "two identical errors with a matching target", | ||
errs: []error{err1, err1}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "wrapped error with a matching target", | ||
errs: []error{fmt.Errorf("error: %w", err1)}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "nested joined error with a matching target", | ||
errs: []error{err1, join(err2, errors.New("nope"))}, | ||
target: &errType2{}, | ||
}, { | ||
desc: "nested joined error with no matching targets", | ||
errs: []error{err1, join(errors.New("nope"), errors.New("nope 2"))}, | ||
target: &errType2{}, | ||
}, { | ||
desc: "nested joined error with a wrapped matching target", | ||
errs: []error{join(fmt.Errorf("error: %w", err1), errors.New("nope")), err2}, | ||
target: &errType1{}, | ||
}, { | ||
desc: "context.DeadlineExceeded", | ||
errs: []error{err1, nil, context.DeadlineExceeded, err2}, | ||
target: &errType2{}, | ||
}} | ||
|
||
for _, test := range tests { | ||
test := test // Capture range variable. | ||
|
||
t.Run(test.desc, func(t *testing.T) { | ||
// Assert that top-level errors returned by errors.Join and join | ||
// behave the same with errors.As. | ||
want := errors.Join(test.errs...) | ||
got := join(test.errs...) | ||
assert.Equal(t, | ||
errors.As(want, test.target), | ||
errors.As(got, test.target), | ||
"errors.Join() and join() behave differently with errors.As") | ||
|
||
// Assert that wrapped errors returned by errors.Join and join | ||
// behave the same with errors.As. | ||
want = fmt.Errorf("error: %w", errors.Join(test.errs...)) | ||
got = fmt.Errorf("error: %w", join(test.errs...)) | ||
assert.Equal(t, | ||
errors.As(want, test.target), | ||
errors.As(got, test.target), | ||
"errors.Join() and join(), when wrapped, behave differently with errors.As") | ||
}) | ||
} | ||
} |