diff --git a/ast/compile.go b/ast/compile.go index c48504606f..dc5d4d44f9 100644 --- a/ast/compile.go +++ b/ast/compile.go @@ -854,7 +854,7 @@ func NewModuleTree(mods map[string]*Module) *ModuleTreeNode { c, ok := node.Children[x.Value] if !ok { var hide bool - if i == 1 && x.Value.Equal(SystemDocumentKey) { + if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 { hide = true } c = &ModuleTreeNode{ diff --git a/ast/map.go b/ast/map.go index 62a8adadd4..eb7f2d7e3c 100644 --- a/ast/map.go +++ b/ast/map.go @@ -110,5 +110,5 @@ func valueHash(v util.T) int { func valueEq(a, b util.T) bool { av := a.(Value) bv := b.(Value) - return av.Equal(bv) + return av.Compare(bv) == 0 } diff --git a/ast/policy.go b/ast/policy.go index af34688de6..3d4a5fbe4a 100644 --- a/ast/policy.go +++ b/ast/policy.go @@ -819,7 +819,7 @@ func (expr *Expr) IsEquality() bool { if len(terms) != 3 { return false } - return terms[0].Value.Equal(Equality.Name) + return terms[0].Value.Compare(Equality.Name) == 0 } // IsBuiltin returns true if this expression refers to a built-in function. diff --git a/ast/term.go b/ast/term.go index d766d8028c..15fa3a3973 100644 --- a/ast/term.go +++ b/ast/term.go @@ -73,7 +73,7 @@ func (loc *Location) String() string { // - Variables, References // - Array Comprehensions type Value interface { - Equal(other Value) bool // Equal returns true if this value equals the other value. + Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively. Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found. Hash() int // Returns hash code of the value. IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. @@ -269,7 +269,7 @@ func (term *Term) Equal(other *Term) bool { if term == other { return true } - return term.Value.Equal(other.Value) + return term.Value.Compare(other.Value) == 0 } // Hash returns the hash code of the Term's value. @@ -377,6 +377,12 @@ func (null Null) Equal(other Value) bool { } } +// Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (null Null) Compare(other Value) int { + return Compare(null, other) +} + // Find returns the current value or a not found error. func (null Null) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -417,6 +423,12 @@ func (bol Boolean) Equal(other Value) bool { } } +// Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (bol Boolean) Compare(other Value) int { + return Compare(bol, other) +} + // Find returns the current value or a not found error. func (bol Boolean) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -472,6 +484,12 @@ func (num Number) Equal(other Value) bool { } } +// Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (num Number) Compare(other Value) int { + return Compare(num, other) +} + // Find returns the current value or a not found error. func (num Number) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -531,6 +549,12 @@ func (str String) Equal(other Value) bool { } } +// Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (str String) Compare(other Value) int { + return Compare(str, other) +} + // Find returns the current value or a not found error. func (str String) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -573,6 +597,12 @@ func (v Var) Equal(other Value) bool { } } +// Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (v Var) Compare(other Value) int { + return Compare(v, other) +} + // Find returns the current value or a not found error. func (v Var) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -667,6 +697,12 @@ func (ref Ref) Equal(other Value) bool { return Compare(ref, other) == 0 } +// Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (ref Ref) Compare(other Value) int { + return Compare(ref, other) +} + // Find returns the current value or a not found error. func (ref Ref) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -786,6 +822,12 @@ func (arr Array) Equal(other Value) bool { return Compare(arr, other) == 0 } +// Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (arr Array) Compare(other Value) int { + return Compare(arr, other) +} + // Find returns the value at the index or an out-of-range error. func (arr Array) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -902,6 +944,12 @@ func (s *Set) Equal(v Value) bool { return Compare(s, v) == 0 } +// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (s *Set) Compare(other Value) int { + return Compare(s, other) +} + // Find returns the current value or a not found error. func (s *Set) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -1019,6 +1067,12 @@ func (obj Object) Equal(other Value) bool { return Compare(obj, other) == 0 } +// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (obj Object) Compare(other Value) int { + return Compare(obj, other) +} + // Find returns the value at the key or undefined. func (obj Object) Find(path Ref) (Value, error) { if len(path) == 0 { @@ -1195,6 +1249,12 @@ func (ac *ArrayComprehension) Equal(other Value) bool { return Compare(ac, other) == 0 } +// Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to, +// or greater than other. +func (ac *ArrayComprehension) Compare(other Value) int { + return Compare(ac, other) +} + // Find returns the current value or a not found error. func (ac *ArrayComprehension) Find(path Ref) (Value, error) { if len(path) == 0 { diff --git a/ast/term_test.go b/ast/term_test.go index a8b1187d81..f21cb17477 100644 --- a/ast/term_test.go +++ b/ast/term_test.go @@ -41,7 +41,7 @@ func TestInterfaceToValue(t *testing.T) { return } - if !v.Equal(expected) { + if v.Compare(expected) != 0 { t.Errorf("Expected ast.Value to equal:\n%v\nBut got:\n%v", expected, v) } } @@ -146,7 +146,7 @@ func TestFind(t *testing.T) { if err != nil { t.Fatalf("Unexpected error occurred for %v: %v", tc.path, err) } - if !result.Equal(expected.Value) { + if result.Compare(expected.Value) != 0 { t.Fatalf("Expected value %v for %v but got: %v", expected, tc.path, result) } case error: diff --git a/ast/unify_test.go b/ast/unify_test.go index 9ee34ab94c..cb6accf8ef 100644 --- a/ast/unify_test.go +++ b/ast/unify_test.go @@ -53,7 +53,7 @@ func TestUnify(t *testing.T) { } terms := expr.Terms.([]*Term) - if !terms[0].Value.Equal(Equality.Name) { + if terms[0].Value.Compare(Equality.Name) != 0 { panic(terms) } diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index c4f968ef8f..205f1db239 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -178,7 +178,7 @@ func TestBasic(t *testing.T) { code, err := response.Find(ast.RefTerm(ast.StringTerm("code")).Value.(ast.Ref)) if err != nil { t.Fatalf("Missing code in response: %v", recorder) - } else if !code.Equal(ast.String(tc.expectedCode)) { + } else if code.Compare(ast.String(tc.expectedCode)) != 0 { t.Fatalf("Expected code %v but got: %v", tc.expectedCode, recorder) } diff --git a/server/server_test.go b/server/server_test.go index 719cfff3f0..e7f9788b35 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -622,7 +622,7 @@ func TestPoliciesPutV1ParseError(t *testing.T) { t.Fatalf("Expecfted to find name in errors but: %v", err) } - if !name.Equal(ast.String("test")) { + if name.Compare(ast.String("test")) != 0 { t.Fatalf("Expected name ot equal test but got: %v", name) } } @@ -658,7 +658,7 @@ q[x] { p[x] }`, t.Fatalf("Expecfted to find name in errors but: %v", err) } - if !name.Equal(ast.String("test")) { + if name.Compare(ast.String("test")) != 0 { t.Fatalf("Expected name ot equal test but got: %v", name) } } diff --git a/topdown/aggregates.go b/topdown/aggregates.go index f93c936bf9..e225acccd4 100644 --- a/topdown/aggregates.go +++ b/topdown/aggregates.go @@ -101,7 +101,7 @@ func builtinMin(a ast.Value) (ast.Value, error) { // The null term is considered to be less than any other term, // so in order for min of a set to make sense, we need to check // for it. - if min.Value.Equal(ast.Null{}) { + if min.Value.Compare(ast.Null{}) == 0 { return elem, nil } diff --git a/topdown/input.go b/topdown/input.go index ae5a950e66..20a2a51a8a 100644 --- a/topdown/input.go +++ b/topdown/input.go @@ -23,7 +23,7 @@ func MakeInput(pairs [][2]*ast.Term) (ast.Value, error) { } // Fast-path for the root case. - if len(pairs) == 1 && pairs[0][0].Value.Equal(ast.InputRootRef) { + if len(pairs) == 1 && pairs[0][0].Value.Compare(ast.InputRootRef) == 0 { return pairs[0][1].Value, nil } diff --git a/topdown/input_test.go b/topdown/input_test.go index 9e870ea308..84be84dc11 100644 --- a/topdown/input_test.go +++ b/topdown/input_test.go @@ -75,7 +75,7 @@ func TestMakeInput(t *testing.T) { continue } expected := ast.MustParseTerm(e) - if !expected.Value.Equal(input) { + if expected.Value.Compare(input) != 0 { t.Errorf("%v (#%d): Expected input to equal %v but got: %v", tc.note, i+1, expected, input) } } diff --git a/topdown/topdown.go b/topdown/topdown.go index 2293c702d4..0b6dddacff 100644 --- a/topdown/topdown.go +++ b/topdown/topdown.go @@ -924,7 +924,7 @@ func evalExpr(t *Topdown, iter Iterator) error { return err } } - if !v.Equal(ast.Boolean(false)) { + if v.Compare(ast.Boolean(false)) != 0 { if v.IsGround() { return iter(t) } @@ -1370,7 +1370,7 @@ func evalRefRuleCompleteDocSingle(t *Topdown, rule *ast.Rule, redo bool, last as // If document is already defined, check for conflict. if last != nil { - if !last.Equal(result) { + if last.Compare(result) != 0 { return completeDocConflictErr(t.currentLocation(rule)) } } else { @@ -1560,7 +1560,7 @@ func evalRefRulePartialObjectDocFull(t *Topdown, ref ast.Ref, rules []*ast.Rule, return fmt.Errorf("unbound variable: %v", value) } - if exist := keys.Get(key); exist != nil && !exist.Equal(value) { + if exist := keys.Get(key); exist != nil && exist.Compare(value) != 0 { return objectDocKeyConflictErr(t.currentLocation(rule)) } @@ -1784,7 +1784,7 @@ func evalRefRuleResultRecObject(t *Topdown, obj ast.Object, ref, path ast.Ref, i return err } } - if x.Equal(k) { + if x.Compare(k) == 0 { match = idx break } @@ -1949,7 +1949,7 @@ func evalTermsIndexed(t *Topdown, iter Iterator, index storage.Index, nonIndexed // different binding for the same variable. This can arise if output // variables in references on either side intersect (e.g., a[i] = g[i][j]). skip := bindings.Iter(func(k, v ast.Value) bool { - if o := t.Binding(k); o != nil && !o.Equal(v) { + if o := t.Binding(k); o != nil && o.Compare(v) != 0 { return true } prev = t.Bind(k, v, prev) diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 1200003762..b079929586 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -220,14 +220,14 @@ func TestPlugValue(t *testing.T) { r1 := PlugValue(a, t1.Binding) - if !expected.Equal(r1) { + if expected.Compare(r1) != 0 { t.Errorf("Expected %v but got %v", expected, r1) return } r2 := PlugValue(a, t2.Binding) - if !expected.Equal(r2) { + if expected.Compare(r2) != 0 { t.Errorf("Expected %v but got %v", expected, r2) } @@ -241,7 +241,7 @@ func TestPlugValue(t *testing.T) { r3 := PlugValue(n, t3.Binding) - if !expected.Equal(r3) { + if expected.Compare(r3) != 0 { t.Errorf("Expected %v but got: %v", expected, r3) } }