Skip to content

Commit

Permalink
Refactor ast.Value to use Compare instead of Equal
Browse files Browse the repository at this point in the history
  • Loading branch information
mmussomele authored and tsandall committed Jun 27, 2017
1 parent 998273c commit 3261995
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ func NewModuleTree(mods map[string]*Module) *ModuleTreeNode {
c, ok := node.Children[x.Value]
if !ok {
var hide bool
if i == 1 && x.Value.Equal(SystemDocumentKey) {
if i == 1 && x.Value.Compare(SystemDocumentKey) == 0 {
hide = true
}
c = &ModuleTreeNode{
Expand Down
2 changes: 1 addition & 1 deletion ast/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,5 @@ func valueHash(v util.T) int {
func valueEq(a, b util.T) bool {
av := a.(Value)
bv := b.(Value)
return av.Equal(bv)
return av.Compare(bv) == 0
}
2 changes: 1 addition & 1 deletion ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ func (expr *Expr) IsEquality() bool {
if len(terms) != 3 {
return false
}
return terms[0].Value.Equal(Equality.Name)
return terms[0].Value.Compare(Equality.Name) == 0
}

// IsBuiltin returns true if this expression refers to a built-in function.
Expand Down
64 changes: 62 additions & 2 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (loc *Location) String() string {
// - Variables, References
// - Array Comprehensions
type Value interface {
Equal(other Value) bool // Equal returns true if this value equals the other value.
Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively.
Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found.
Hash() int // Returns hash code of the value.
IsGround() bool // IsGround returns true if this value is not a variable or contains no variables.
Expand Down Expand Up @@ -269,7 +269,7 @@ func (term *Term) Equal(other *Term) bool {
if term == other {
return true
}
return term.Value.Equal(other.Value)
return term.Value.Compare(other.Value) == 0
}

// Hash returns the hash code of the Term's value.
Expand Down Expand Up @@ -377,6 +377,12 @@ func (null Null) Equal(other Value) bool {
}
}

// Compare compares null to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (null Null) Compare(other Value) int {
return Compare(null, other)
}

// Find returns the current value or a not found error.
func (null Null) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -417,6 +423,12 @@ func (bol Boolean) Equal(other Value) bool {
}
}

// Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (bol Boolean) Compare(other Value) int {
return Compare(bol, other)
}

// Find returns the current value or a not found error.
func (bol Boolean) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -472,6 +484,12 @@ func (num Number) Equal(other Value) bool {
}
}

// Compare compares num to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (num Number) Compare(other Value) int {
return Compare(num, other)
}

// Find returns the current value or a not found error.
func (num Number) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -531,6 +549,12 @@ func (str String) Equal(other Value) bool {
}
}

// Compare compares str to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (str String) Compare(other Value) int {
return Compare(str, other)
}

// Find returns the current value or a not found error.
func (str String) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -573,6 +597,12 @@ func (v Var) Equal(other Value) bool {
}
}

// Compare compares v to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (v Var) Compare(other Value) int {
return Compare(v, other)
}

// Find returns the current value or a not found error.
func (v Var) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -667,6 +697,12 @@ func (ref Ref) Equal(other Value) bool {
return Compare(ref, other) == 0
}

// Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (ref Ref) Compare(other Value) int {
return Compare(ref, other)
}

// Find returns the current value or a not found error.
func (ref Ref) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -786,6 +822,12 @@ func (arr Array) Equal(other Value) bool {
return Compare(arr, other) == 0
}

// Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (arr Array) Compare(other Value) int {
return Compare(arr, other)
}

// Find returns the value at the index or an out-of-range error.
func (arr Array) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -902,6 +944,12 @@ func (s *Set) Equal(v Value) bool {
return Compare(s, v) == 0
}

// Compare compares s to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (s *Set) Compare(other Value) int {
return Compare(s, other)
}

// Find returns the current value or a not found error.
func (s *Set) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -1019,6 +1067,12 @@ func (obj Object) Equal(other Value) bool {
return Compare(obj, other) == 0
}

// Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (obj Object) Compare(other Value) int {
return Compare(obj, other)
}

// Find returns the value at the key or undefined.
func (obj Object) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down Expand Up @@ -1195,6 +1249,12 @@ func (ac *ArrayComprehension) Equal(other Value) bool {
return Compare(ac, other) == 0
}

// Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to,
// or greater than other.
func (ac *ArrayComprehension) Compare(other Value) int {
return Compare(ac, other)
}

// Find returns the current value or a not found error.
func (ac *ArrayComprehension) Find(path Ref) (Value, error) {
if len(path) == 0 {
Expand Down
4 changes: 2 additions & 2 deletions ast/term_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestInterfaceToValue(t *testing.T) {
return
}

if !v.Equal(expected) {
if v.Compare(expected) != 0 {
t.Errorf("Expected ast.Value to equal:\n%v\nBut got:\n%v", expected, v)
}
}
Expand Down Expand Up @@ -146,7 +146,7 @@ func TestFind(t *testing.T) {
if err != nil {
t.Fatalf("Unexpected error occurred for %v: %v", tc.path, err)
}
if !result.Equal(expected.Value) {
if result.Compare(expected.Value) != 0 {
t.Fatalf("Expected value %v for %v but got: %v", expected, tc.path, result)
}
case error:
Expand Down
2 changes: 1 addition & 1 deletion ast/unify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestUnify(t *testing.T) {
}

terms := expr.Terms.([]*Term)
if !terms[0].Value.Equal(Equality.Name) {
if terms[0].Value.Compare(Equality.Name) != 0 {
panic(terms)
}

Expand Down
2 changes: 1 addition & 1 deletion server/authorizer/authorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestBasic(t *testing.T) {
code, err := response.Find(ast.RefTerm(ast.StringTerm("code")).Value.(ast.Ref))
if err != nil {
t.Fatalf("Missing code in response: %v", recorder)
} else if !code.Equal(ast.String(tc.expectedCode)) {
} else if code.Compare(ast.String(tc.expectedCode)) != 0 {
t.Fatalf("Expected code %v but got: %v", tc.expectedCode, recorder)
}

Expand Down
4 changes: 2 additions & 2 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ func TestPoliciesPutV1ParseError(t *testing.T) {
t.Fatalf("Expecfted to find name in errors but: %v", err)
}

if !name.Equal(ast.String("test")) {
if name.Compare(ast.String("test")) != 0 {
t.Fatalf("Expected name ot equal test but got: %v", name)
}
}
Expand Down Expand Up @@ -658,7 +658,7 @@ q[x] { p[x] }`,
t.Fatalf("Expecfted to find name in errors but: %v", err)
}

if !name.Equal(ast.String("test")) {
if name.Compare(ast.String("test")) != 0 {
t.Fatalf("Expected name ot equal test but got: %v", name)
}
}
Expand Down
2 changes: 1 addition & 1 deletion topdown/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func builtinMin(a ast.Value) (ast.Value, error) {
// The null term is considered to be less than any other term,
// so in order for min of a set to make sense, we need to check
// for it.
if min.Value.Equal(ast.Null{}) {
if min.Value.Compare(ast.Null{}) == 0 {
return elem, nil
}

Expand Down
2 changes: 1 addition & 1 deletion topdown/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func MakeInput(pairs [][2]*ast.Term) (ast.Value, error) {
}

// Fast-path for the root case.
if len(pairs) == 1 && pairs[0][0].Value.Equal(ast.InputRootRef) {
if len(pairs) == 1 && pairs[0][0].Value.Compare(ast.InputRootRef) == 0 {
return pairs[0][1].Value, nil
}

Expand Down
2 changes: 1 addition & 1 deletion topdown/input_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestMakeInput(t *testing.T) {
continue
}
expected := ast.MustParseTerm(e)
if !expected.Value.Equal(input) {
if expected.Value.Compare(input) != 0 {
t.Errorf("%v (#%d): Expected input to equal %v but got: %v", tc.note, i+1, expected, input)
}
}
Expand Down
10 changes: 5 additions & 5 deletions topdown/topdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ func evalExpr(t *Topdown, iter Iterator) error {
return err
}
}
if !v.Equal(ast.Boolean(false)) {
if v.Compare(ast.Boolean(false)) != 0 {
if v.IsGround() {
return iter(t)
}
Expand Down Expand Up @@ -1370,7 +1370,7 @@ func evalRefRuleCompleteDocSingle(t *Topdown, rule *ast.Rule, redo bool, last as

// If document is already defined, check for conflict.
if last != nil {
if !last.Equal(result) {
if last.Compare(result) != 0 {
return completeDocConflictErr(t.currentLocation(rule))
}
} else {
Expand Down Expand Up @@ -1560,7 +1560,7 @@ func evalRefRulePartialObjectDocFull(t *Topdown, ref ast.Ref, rules []*ast.Rule,
return fmt.Errorf("unbound variable: %v", value)
}

if exist := keys.Get(key); exist != nil && !exist.Equal(value) {
if exist := keys.Get(key); exist != nil && exist.Compare(value) != 0 {
return objectDocKeyConflictErr(t.currentLocation(rule))
}

Expand Down Expand Up @@ -1784,7 +1784,7 @@ func evalRefRuleResultRecObject(t *Topdown, obj ast.Object, ref, path ast.Ref, i
return err
}
}
if x.Equal(k) {
if x.Compare(k) == 0 {
match = idx
break
}
Expand Down Expand Up @@ -1949,7 +1949,7 @@ func evalTermsIndexed(t *Topdown, iter Iterator, index storage.Index, nonIndexed
// different binding for the same variable. This can arise if output
// variables in references on either side intersect (e.g., a[i] = g[i][j]).
skip := bindings.Iter(func(k, v ast.Value) bool {
if o := t.Binding(k); o != nil && !o.Equal(v) {
if o := t.Binding(k); o != nil && o.Compare(v) != 0 {
return true
}
prev = t.Bind(k, v, prev)
Expand Down
6 changes: 3 additions & 3 deletions topdown/topdown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ func TestPlugValue(t *testing.T) {

r1 := PlugValue(a, t1.Binding)

if !expected.Equal(r1) {
if expected.Compare(r1) != 0 {
t.Errorf("Expected %v but got %v", expected, r1)
return
}

r2 := PlugValue(a, t2.Binding)

if !expected.Equal(r2) {
if expected.Compare(r2) != 0 {
t.Errorf("Expected %v but got %v", expected, r2)
}

Expand All @@ -241,7 +241,7 @@ func TestPlugValue(t *testing.T) {

r3 := PlugValue(n, t3.Binding)

if !expected.Equal(r3) {
if expected.Compare(r3) != 0 {
t.Errorf("Expected %v but got: %v", expected, r3)
}
}
Expand Down

0 comments on commit 3261995

Please sign in to comment.