Skip to content

Commit

Permalink
topdown/copypropagation: unionfind now uses ast.Value
Browse files Browse the repository at this point in the history
The underlying code will work in the same way, it doesn't affect the
algorithm. We just were limiting to ast.Var before because of the
initial use case only wanting to group variables.

Signed-off-by: Patrick East <east.patrick@gmail.com>
  • Loading branch information
patrick-east authored and tsandall committed Apr 24, 2020
1 parent 0c9f5aa commit a36e124
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 178 deletions.
14 changes: 7 additions & 7 deletions topdown/copypropagation/copypropagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ func (p *CopyPropagator) Apply(query ast.Body) (result ast.Body) {
if v, ok := x[0].Value.(ast.Var); ok {
if root, ok := uf.Find(v); ok {
root.constant = nil
headvars.Add(root.key)
headvars.Add(root.key.(ast.Var))
} else {
headvars.Add(v)
}
}
return false
})

bindings := map[ast.Var]*binding{}
bindings := map[ast.Value]*binding{}

for _, expr := range query {

Expand Down Expand Up @@ -218,7 +218,7 @@ func (t bindingPlugTransform) plugBindingsVar(pctx *plugContext, v ast.Var) (res
func (t bindingPlugTransform) plugBindingsRef(pctx *plugContext, v ast.Ref) ast.Ref {

// Apply union-find to remove redundant variables from input.
if root, ok := pctx.uf.Find(v[0].Value.(ast.Var)); ok {
if root, ok := pctx.uf.Find(v[0].Value); ok {
v[0].Value = root.Value()
}

Expand Down Expand Up @@ -289,14 +289,14 @@ func (p *CopyPropagator) updateBindingsEqAsymmetric(a, b *ast.Term) (ast.Var, as
}

type plugContext struct {
bindings map[ast.Var]*binding
bindings map[ast.Value]*binding
uf *unionFind
headvars ast.VarSet
negated bool
}

type binding struct {
k ast.Var
k ast.Value
v ast.Value
}

Expand Down Expand Up @@ -327,7 +327,7 @@ func (b *binding) containedIn(query ast.Body) bool {
return stop
}

func sortbindings(bindings map[ast.Var]*binding) []*binding {
func sortbindings(bindings map[ast.Value]*binding) []*binding {
sorted := make([]*binding, 0, len(bindings))
for _, b := range bindings {
sorted = append(sorted, b)
Expand All @@ -346,7 +346,7 @@ func sortbindings(bindings map[ast.Var]*binding) []*binding {
// false.
func makeDisjointSets(livevars ast.VarSet, query ast.Body) (*unionFind, bool) {
uf := newUnionFind(func(r1, r2 *unionFindRoot) (*unionFindRoot, *unionFindRoot) {
if livevars.Contains(r1.key) {
if v, ok := r1.key.(ast.Var); ok && livevars.Contains(v) {
return r1, r2
}
return r2, r1
Expand Down
81 changes: 62 additions & 19 deletions topdown/copypropagation/unionfind.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,62 @@

package copypropagation

import "github.com/open-policy-agent/opa/ast"
import (
"fmt"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/util"
)

type rankFunc func(*unionFindRoot, *unionFindRoot) (*unionFindRoot, *unionFindRoot)

type unionFind struct {
roots map[ast.Var]*unionFindRoot
parents map[ast.Var]ast.Var
roots *util.HashMap
parents *ast.ValueMap
rank rankFunc
}

func newUnionFind(rank rankFunc) *unionFind {
return &unionFind{
roots: map[ast.Var]*unionFindRoot{},
parents: map[ast.Var]ast.Var{},
roots: util.NewHashMap(func(a util.T, b util.T) bool {
return a.(ast.Value).Compare(b.(ast.Value)) == 0
}, func(v util.T) int {
return v.(ast.Value).Hash()
}),
parents: ast.NewValueMap(),
rank: rank,
}
}

func (uf *unionFind) MakeSet(v ast.Var) *unionFindRoot {
func (uf *unionFind) MakeSet(v ast.Value) *unionFindRoot {

root, ok := uf.Find(v)
if ok {
return root
}

root = newUnionFindRoot(v)
uf.parents[v] = v
uf.roots[v] = root
return uf.roots[v]
uf.parents.Put(v, v)
uf.roots.Put(v, root)
return root
}

func (uf *unionFind) Find(v ast.Var) (*unionFindRoot, bool) {
func (uf *unionFind) Find(v ast.Value) (*unionFindRoot, bool) {

parent, ok := uf.parents[v]
if !ok {
parent := uf.parents.Get(v)
if parent == nil {
return nil, false
}

if parent == v {
return uf.roots[v], true
if parent.Compare(v) == 0 {
r, ok := uf.roots.Get(v)
return r.(*unionFindRoot), ok
}

return uf.Find(parent)
}

func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
func (uf *unionFind) Merge(a, b ast.Value) (*unionFindRoot, bool) {

r1 := uf.MakeSet(a)
r2 := uf.MakeSet(b)
Expand All @@ -57,8 +68,8 @@ func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {

r1, r2 = uf.rank(r1, r2)

uf.parents[r2.key] = r1.key
delete(uf.roots, r2.key)
uf.parents.Put(r2.key, r1.key)
uf.roots.Delete(r2.key)

// Sets can have at most one constant value associated with them. When
// unioning, we must preserve this invariant. If a set has two constants,
Expand All @@ -73,12 +84,40 @@ func (uf *unionFind) Merge(a, b ast.Var) (*unionFindRoot, bool) {
return r1, true
}

func (uf *unionFind) String() string {
o := struct {
Roots map[string]interface{}
Parents map[string]ast.Value
}{
map[string]interface{}{},
map[string]ast.Value{},
}

uf.roots.Iter(func(k util.T, v util.T) bool {
o.Roots[k.(ast.Value).String()] = struct {
Constant *ast.Term
Key ast.Value
}{
v.(*unionFindRoot).constant,
v.(*unionFindRoot).key,
}
return true
})

uf.parents.Iter(func(k ast.Value, v ast.Value) bool {
o.Parents[k.String()] = v
return true
})

return string(util.MustMarshalJSON(o))
}

type unionFindRoot struct {
key ast.Var
key ast.Value
constant *ast.Term
}

func newUnionFindRoot(key ast.Var) *unionFindRoot {
func newUnionFindRoot(key ast.Value) *unionFindRoot {
return &unionFindRoot{
key: key,
}
Expand All @@ -90,3 +129,7 @@ func (r *unionFindRoot) Value() ast.Value {
}
return r.key
}

func (r *unionFindRoot) String() string {
return fmt.Sprintf("{key: %s, constant: %s", r.key, r.constant)
}
Loading

0 comments on commit a36e124

Please sign in to comment.