Skip to content

Commit

Permalink
Merge pull request #151 from hpidcock/multichecker
Browse files Browse the repository at this point in the history
#151

MultiChecker allows you to perform a DeepEquals but have bespoke checkers based on path matching.

For example, this allows the value at "b" to be ignored.
```
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"}

checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore)
c.Check(a1, checker, a2)
```

This allows the second element to have the SameContents check applied, ignoring order of elements:
```
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}}

checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue)
c.Check(a1, checker, a2)
```
  • Loading branch information
jujubot authored Jun 8, 2020
2 parents 6c8c298 + 543482a commit e4eedbc
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 8 deletions.
13 changes: 13 additions & 0 deletions checkers/bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,16 @@ func (checker *deepEqualsChecker) Check(params []interface{}, names []string) (r
}
return true, ""
}

type ignoreChecker struct {
*gc.CheckerInfo
}

// Ignore always succeeds.
var Ignore gc.Checker = &ignoreChecker{
&gc.CheckerInfo{Name: "Ignore", Params: []string{"obtained"}},
}

func (checker *ignoreChecker) Check(params []interface{}, names []string) (result bool, error string) {
return true, ""
}
67 changes: 59 additions & 8 deletions checkers/deepequal.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func printable(v reflect.Value) interface{} {
// Tests for deep equality using reflected types. The map argument tracks
// comparisons that have already been seen, which allows short circuiting on
// recursive types.
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int) (ok bool, err error) {
func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, depth int, customCheckFunc CustomCheckFunc) (ok bool, err error) {
errorf := func(f string, a ...interface{}) error {
return &mismatchError{
v1: v1,
Expand Down Expand Up @@ -105,6 +105,13 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
visited[v] = true
}

if customCheckFunc != nil && v1.CanInterface() && v2.CanInterface() {
useDefault, equal, err := customCheckFunc(path, v1.Interface(), v2.Interface())
if !useDefault {
return equal, err
}
}

switch v1.Kind() {
case reflect.Array:
if v1.Len() != v2.Len() {
Expand All @@ -114,7 +121,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
for i := 0; i < v1.Len(); i++ {
if ok, err := deepValueEqual(
fmt.Sprintf("%s[%d]", path, i),
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -130,7 +137,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
for i := 0; i < v1.Len(); i++ {
if ok, err := deepValueEqual(
fmt.Sprintf("%s[%d]", path, i),
v1.Index(i), v2.Index(i), visited, depth+1); !ok {
v1.Index(i), v2.Index(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -142,9 +149,9 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
}
return true, nil
}
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1)
return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
case reflect.Ptr:
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1)
return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visited, depth+1, customCheckFunc)
case reflect.Struct:
if v1.Type() == timeType {
// Special case for time - we ignore the time zone.
Expand All @@ -157,7 +164,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
}
for i, n := 0, v1.NumField(); i < n; i++ {
path := path + "." + v1.Type().Field(i).Name
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1); !ok {
if ok, err := deepValueEqual(path, v1.Field(i), v2.Field(i), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand All @@ -179,7 +186,7 @@ func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
} else {
p = path + "[someKey]"
}
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1); !ok {
if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIndex(k), visited, depth+1, customCheckFunc); !ok {
return false, err
}
}
Expand Down Expand Up @@ -263,9 +270,53 @@ func DeepEqual(a1, a2 interface{}) (bool, error) {
if v1.Type() != v2.Type() {
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
}
return deepValueEqual("", v1, v2, make(map[visit]bool), 0)
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, nil)
}

// DeepEqualWithCustomCheck tests for deep equality. It uses normal == equality where
// possible but will scan elements of arrays, slices, maps, and fields
// of structs. In maps, keys are compared with == but elements use deep
// equality. DeepEqual correctly handles recursive types. Functions are
// equal only if they are both nil.
//
// DeepEqual differs from reflect.DeepEqual in two ways:
// - an empty slice is considered equal to a nil slice.
// - two time.Time values that represent the same instant
// but with different time zones are considered equal.
//
// If the two values compare unequal, the resulting error holds the
// first difference encountered.
//
// If both values are interface-able and customCheckFunc is non nil,
// customCheckFunc will be invoked. If it returns useDefault as true, the
// DeepEqual continues, otherwise the result of the customCheckFunc is used.
func DeepEqualWithCustomCheck(a1 interface{}, a2 interface{}, customCheckFunc CustomCheckFunc) (bool, error) {
errorf := func(f string, a ...interface{}) error {
return &mismatchError{
v1: reflect.ValueOf(a1),
v2: reflect.ValueOf(a2),
path: "",
how: fmt.Sprintf(f, a...),
}
}
if a1 == nil || a2 == nil {
if a1 == a2 {
return true, nil
}
return false, errorf("nil vs non-nil mismatch")
}
v1 := reflect.ValueOf(a1)
v2 := reflect.ValueOf(a2)
if v1.Type() != v2.Type() {
return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Type())
}
return deepValueEqual("", v1, v2, make(map[visit]bool), 0, customCheckFunc)
}

// CustomCheckFunc should return true for useDefault if DeepEqualWithCustomCheck should behave like DeepEqual.
// Otherwise the result of the CustomCheckFunc is used.
type CustomCheckFunc func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error)

// interfaceOf returns v.Interface() even if v.CanInterface() == false.
// This enables us to call fmt.Printf on a value even if it's derived
// from inside an unexported field.
Expand Down
113 changes: 113 additions & 0 deletions checkers/multichecker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2020 Canonical Ltd.
// Licensed under the LGPLv3, see LICENCE file for details.

package checkers

import (
"fmt"
"regexp"

gc "gopkg.in/check.v1"
)

// MultiChecker is a deep checker that by default matches for equality.
// But checks can be overriden based on path (either explicit match or regexp)
type MultiChecker struct {
*gc.CheckerInfo
checks map[string]multiCheck
regexChecks []regexCheck
}

type multiCheck struct {
checker gc.Checker
args []interface{}
}

type regexCheck struct {
multiCheck
regex *regexp.Regexp
}

// NewMultiChecker creates a MultiChecker which is a deep checker that by default matches for equality.
// But checks can be overriden based on path (either explicit match or regexp)
func NewMultiChecker() *MultiChecker {
return &MultiChecker{
CheckerInfo: &gc.CheckerInfo{Name: "MultiChecker", Params: []string{"obtained", "expected"}},
checks: make(map[string]multiCheck),
}
}

// Add an explict checker by path.
func (checker *MultiChecker) Add(path string, c gc.Checker, args ...interface{}) *MultiChecker {
checker.checks[path] = multiCheck{
checker: c,
args: args,
}
return checker
}

// AddRegex exception which matches path with regex.
func (checker *MultiChecker) AddRegex(pathRegex string, c gc.Checker, args ...interface{}) *MultiChecker {
checker.regexChecks = append(checker.regexChecks, regexCheck{
multiCheck: multiCheck{
checker: c,
args: args,
},
regex: regexp.MustCompile("^" + pathRegex + "$"),
})
return checker
}

// Check for go check Checker interface.
func (checker *MultiChecker) Check(params []interface{}, names []string) (result bool, errStr string) {
customCheckFunc := func(path string, a1 interface{}, a2 interface{}) (useDefault bool, equal bool, err error) {
var mc *multiCheck
if c, ok := checker.checks[path]; ok {
mc = &c
} else {
for _, v := range checker.regexChecks {
if v.regex.MatchString(path) {
mc = &v.multiCheck
break
}
}
}
if mc == nil {
return true, false, nil
}

params := append([]interface{}{a1}, mc.args...)
info := mc.checker.Info()
if len(params) < len(info.Params) {
return false, false, fmt.Errorf("Wrong number of parameters for %s: want %d, got %d", info.Name, len(info.Params), len(params)+1)
}
// Copy since it may be mutated by Check.
names := append([]string{}, info.Params...)

// Trim to the expected params len.
params = params[:len(info.Params)]

// Perform substitution
for i, v := range params {
if v == ExpectedValue {
params[i] = a2
}
}

result, errStr := mc.checker.Check(params, names)
if result {
return false, true, nil
}
if path == "" {
path = "top level"
}
return false, false, fmt.Errorf("mismatch at %s: %s", path, errStr)
}
if ok, err := DeepEqualWithCustomCheck(params[0], params[1], customCheckFunc); !ok {
return false, err.Error()
}
return true, ""
}

// ExpectedValue if passed to MultiChecker.Add or MultiChecker.AddRegex, will be substituded with the expected value.
var ExpectedValue = &struct{}{}
71 changes: 71 additions & 0 deletions checkers/multichecker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package checkers_test

import (
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"
)

type MultiCheckerSuite struct{}

var _ = gc.Suite(&MultiCheckerSuite{})

func (s *MultiCheckerSuite) TestDeepEquals(c *gc.C) {
for i, test := range deepEqualTests {
c.Logf("test %d. %v == %v is %v", i, test.a, test.b, test.eq)
result, msg := jc.NewMultiChecker().Check([]interface{}{test.a, test.b}, nil)
c.Check(result, gc.Equals, test.eq)
if test.eq {
c.Check(msg, gc.Equals, "")
} else {
c.Check(msg, gc.Not(gc.Equals), "")
}
}
}

func (s *MultiCheckerSuite) TestArray(c *gc.C) {
a1 := []string{"a", "b", "c"}
a2 := []string{"a", "bbb", "c"}

checker := jc.NewMultiChecker().Add("[1]", jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestMap(c *gc.C) {
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
a2 := map[string]string{"a": "a", "b": "bbbb", "c": "c"}

checker := jc.NewMultiChecker().Add(`["b"]`, jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestRegexArray(c *gc.C) {
a1 := []string{"a", "b", "c"}
a2 := []string{"a", "bbb", "ccc"}

checker := jc.NewMultiChecker().AddRegex("\\[[1-2]\\]", jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestRegexMap(c *gc.C) {
a1 := map[string]string{"a": "a", "b": "b", "c": "c"}
a2 := map[string]string{"a": "aaaa", "b": "bbbb", "c": "cccc"}

checker := jc.NewMultiChecker().AddRegex(`\[".*"\]`, jc.Ignore)
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestArrayArraysUnordered(c *gc.C) {
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
a2 := [][]string{{"a", "b", "c"}, {}}

checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, []string{"e", "c", "d"})
c.Check(a1, checker, a2)
}

func (s *MultiCheckerSuite) TestArrayArraysUnorderedWithExpected(c *gc.C) {
a1 := [][]string{{"a", "b", "c"}, {"c", "d", "e"}}
a2 := [][]string{{"a", "b", "c"}, {"e", "c", "d"}}

checker := jc.NewMultiChecker().Add("[1]", jc.SameContents, jc.ExpectedValue)
c.Check(a1, checker, a2)
}

0 comments on commit e4eedbc

Please sign in to comment.