From 83198c2c50a6190cf9b038c485c65c5ed8b6cd3a Mon Sep 17 00:00:00 2001 From: Torkel Rogstad Date: Mon, 17 Jan 2022 14:34:50 +0100 Subject: [PATCH] assert: guard CanConvert call in backward compatible wrapper --- assert/assertion_compare.go | 2 +- assert/assertion_compare_can_convert.go | 11 ++++++ assert/assertion_compare_go1.17_test.go | 49 +++++++++++++++++++++++++ assert/assertion_compare_legacy.go | 11 ++++++ assert/assertion_compare_test.go | 4 -- 5 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 assert/assertion_compare_can_convert.go create mode 100644 assert/assertion_compare_go1.17_test.go create mode 100644 assert/assertion_compare_legacy.go diff --git a/assert/assertion_compare.go b/assert/assertion_compare.go index 96027d1ec..3bb22a971 100644 --- a/assert/assertion_compare.go +++ b/assert/assertion_compare.go @@ -306,7 +306,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) { case reflect.Struct: { // All structs enter here. We're not interested in most types. - if !obj1Value.CanConvert(timeType) { + if !canConvert(obj1Value, timeType) { break } diff --git a/assert/assertion_compare_can_convert.go b/assert/assertion_compare_can_convert.go new file mode 100644 index 000000000..1838be3fe --- /dev/null +++ b/assert/assertion_compare_can_convert.go @@ -0,0 +1,11 @@ +// +build go1.17 + +package assert + +import "reflect" + +// Wrapper around reflect.Value.CanConvert, for compatability +// reasons. +func canConvert(value reflect.Value, to reflect.Type) bool { + return value.CanConvert(to) +} diff --git a/assert/assertion_compare_go1.17_test.go b/assert/assertion_compare_go1.17_test.go new file mode 100644 index 000000000..0511e648a --- /dev/null +++ b/assert/assertion_compare_go1.17_test.go @@ -0,0 +1,49 @@ +// +build go1.17 + +package assert + +import ( + "reflect" + "testing" + "time" +) + +func TestCompare17(t *testing.T) { + type customTime time.Time + for _, currCase := range []struct { + less interface{} + greater interface{} + cType string + }{ + {less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"}, + {less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"}, + } { + resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object should be comparable for type " + currCase.cType) + } + + if resLess != compareLess { + t.Errorf("object less (%v) should be less than greater (%v) for type "+currCase.cType, + currCase.less, currCase.greater) + } + + resGreater, isComparable := compare(currCase.greater, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resGreater != compareGreater { + t.Errorf("object greater should be greater than less for type " + currCase.cType) + } + + resEqual, isComparable := compare(currCase.less, currCase.less, reflect.ValueOf(currCase.less).Kind()) + if !isComparable { + t.Error("object are comparable for type " + currCase.cType) + } + + if resEqual != 0 { + t.Errorf("objects should be equal for type " + currCase.cType) + } + } +} diff --git a/assert/assertion_compare_legacy.go b/assert/assertion_compare_legacy.go new file mode 100644 index 000000000..478e953c0 --- /dev/null +++ b/assert/assertion_compare_legacy.go @@ -0,0 +1,11 @@ +// +build !go1.17 + +package assert + +import "reflect" + +// Older versions of Go does not have the reflect.Value.CanConvert +// method. +func canConvert(value reflect.Value, to reflect.Type) bool { + return false +} diff --git a/assert/assertion_compare_test.go b/assert/assertion_compare_test.go index 1af4b5da2..a38d88060 100644 --- a/assert/assertion_compare_test.go +++ b/assert/assertion_compare_test.go @@ -6,7 +6,6 @@ import ( "reflect" "runtime" "testing" - "time" ) func TestCompare(t *testing.T) { @@ -23,7 +22,6 @@ func TestCompare(t *testing.T) { type customFloat32 float32 type customFloat64 float64 type customString string - type customTime time.Time for _, currCase := range []struct { less interface{} greater interface{} @@ -54,8 +52,6 @@ func TestCompare(t *testing.T) { {less: customFloat32(1.23), greater: customFloat32(2.23), cType: "float32"}, {less: float64(1.23), greater: float64(2.34), cType: "float64"}, {less: customFloat64(1.23), greater: customFloat64(2.34), cType: "float64"}, - {less: time.Now(), greater: time.Now().Add(time.Hour), cType: "time.Time"}, - {less: customTime(time.Now()), greater: customTime(time.Now().Add(time.Hour)), cType: "time.Time"}, } { resLess, isComparable := compare(currCase.less, currCase.greater, reflect.ValueOf(currCase.less).Kind()) if !isComparable {