Skip to content

Commit 509dadd

Browse files
committed
markers: avoid panic on non-comparable structs
If an error struct implements `error` by value and the struct is incomparable, the previous implementation of `Is` would panic. This patch fixes it. Inspired from https://go-review.googlesource.com/c/go/+/175260
1 parent 2170583 commit 509dadd

File tree

2 files changed

+61
-16
lines changed

2 files changed

+61
-16
lines changed

markers/markers.go

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@ func Is(err, reference error) bool {
4646
return err == nil
4747
}
4848

49+
isComparable := reflect.TypeOf(reference).Comparable()
50+
4951
// Direct reference comparison is the fastest, and most
5052
// likely to be true, so do this first.
5153
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
52-
if c == reference {
54+
if isComparable && c == reference {
5355
return true
5456
}
5557
// Compatibility with std go errors: if the error object itself
@@ -141,10 +143,27 @@ func If(err error, pred func(err error) (interface{}, bool)) (interface{}, bool)
141143
// package location or a different type, ensure that
142144
// RegisterTypeMigration() was called prior to IsAny().
143145
func IsAny(err error, references ...error) bool {
146+
if err == nil {
147+
for _, refErr := range references {
148+
if refErr == nil {
149+
return true
150+
}
151+
}
152+
// The mark-based comparison below will never match anything if
153+
// the error is nil, so don't bother with computing the marks in
154+
// that case. This avoids the computational expense of computing
155+
// the reference marks upfront.
156+
return false
157+
}
158+
144159
// First try using direct reference comparison.
145-
for c := err; ; c = errbase.UnwrapOnce(c) {
160+
for c := err; c != nil; c = errbase.UnwrapOnce(c) {
146161
for _, refErr := range references {
147-
if c == refErr {
162+
if refErr == nil {
163+
continue
164+
}
165+
isComparable := reflect.TypeOf(refErr).Comparable()
166+
if isComparable && c == refErr {
148167
return true
149168
}
150169
// Compatibility with std go errors: if the error object itself
@@ -153,19 +172,6 @@ func IsAny(err error, references ...error) bool {
153172
return true
154173
}
155174
}
156-
if c == nil {
157-
// This special case is to support a comparison to a nil
158-
// reference.
159-
break
160-
}
161-
}
162-
163-
if err == nil {
164-
// The mark-based comparison below will never match anything if
165-
// the error is nil, so don't bother with computing the marks in
166-
// that case. This avoids the computational expense of computing
167-
// the reference marks upfront.
168-
return false
169175
}
170176

171177
// Try harder with marks.

markers/markers_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,3 +599,42 @@ func (e *errWithIs) Is(o error) bool {
599599
}
600600
return false
601601
}
602+
603+
func TestCompareUncomparable(t *testing.T) {
604+
tt := testutils.T{T: t}
605+
606+
err1 := errors.New("hello")
607+
var nilErr error
608+
f := []string{"woo"}
609+
tt.Check(markers.Is(errorUncomparable{f}, errorUncomparable{}))
610+
tt.Check(markers.IsAny(errorUncomparable{f}, errorUncomparable{}))
611+
tt.Check(markers.IsAny(errorUncomparable{f}, nilErr, errorUncomparable{}))
612+
tt.Check(!markers.Is(errorUncomparable{f}, &errorUncomparable{}))
613+
tt.Check(!markers.IsAny(errorUncomparable{f}, &errorUncomparable{}))
614+
tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, &errorUncomparable{}))
615+
tt.Check(markers.Is(&errorUncomparable{f}, errorUncomparable{}))
616+
tt.Check(markers.IsAny(&errorUncomparable{f}, errorUncomparable{}))
617+
tt.Check(markers.IsAny(&errorUncomparable{f}, nilErr, errorUncomparable{}))
618+
tt.Check(!markers.Is(&errorUncomparable{f}, &errorUncomparable{}))
619+
tt.Check(!markers.IsAny(&errorUncomparable{f}, &errorUncomparable{}))
620+
tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, &errorUncomparable{}))
621+
tt.Check(!markers.Is(errorUncomparable{f}, err1))
622+
tt.Check(!markers.IsAny(errorUncomparable{f}, err1))
623+
tt.Check(!markers.IsAny(errorUncomparable{f}, nilErr, err1))
624+
tt.Check(!markers.Is(&errorUncomparable{f}, err1))
625+
tt.Check(!markers.IsAny(&errorUncomparable{f}, err1))
626+
tt.Check(!markers.IsAny(&errorUncomparable{f}, nilErr, err1))
627+
}
628+
629+
type errorUncomparable struct {
630+
f []string
631+
}
632+
633+
func (e errorUncomparable) Error() string {
634+
return fmt.Sprintf("uncomparable error %d", len(e.f))
635+
}
636+
637+
func (errorUncomparable) Is(target error) bool {
638+
_, ok := target.(errorUncomparable)
639+
return ok
640+
}

0 commit comments

Comments
 (0)