Skip to content

Commit

Permalink
Use func ops for all data structures
Browse files Browse the repository at this point in the history
  • Loading branch information
zyedidia committed Dec 19, 2021
1 parent 29240d8 commit 45189a2
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 61 deletions.
51 changes: 27 additions & 24 deletions avl/avl.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,37 @@ import (
"github.com/zyedidia/generic/iter"
)

type KV[K g.Lesser[K], V any] struct {
type KV[K any, V any] struct {
Key K
Val V
}

// Tree implements an AVL tree.
type Tree[K g.Lesser[K], V any] struct {
type Tree[K any, V any] struct {
root *node[K, V]
less g.Lesser[K]
}

// New returns an empty AVL tree.
func New[K g.Lesser[K], V any]() *Tree[K, V] {
return &Tree[K, V]{}
func New[K any, V any](less g.Lesser[K]) *Tree[K, V] {
return &Tree[K, V]{
less: less,
}
}

// Put associates 'key' with 'value'.
func (t *Tree[K, V]) Put(key K, value V) {
t.root = t.root.add(key, value)
t.root = t.root.add(key, value, t.less)
}

// Remove removes the value associated with 'key'.
func (t *Tree[K, V]) Remove(key K) {
t.root = t.root.remove(key)
t.root = t.root.remove(key, t.less)
}

// Get returns the value associated with 'key'.
func (t *Tree[K, V]) Get(key K) (V, bool) {
n := t.root.search(key)
n := t.root.search(key, t.less)
if n == nil {
var v V
return v, false
Expand All @@ -60,7 +63,7 @@ func (t *Tree[K, V]) Size() int {
return t.root.size()
}

type node[K g.Lesser[K], V any] struct {
type node[K any, V any] struct {
key K
value V

Expand All @@ -69,7 +72,7 @@ type node[K g.Lesser[K], V any] struct {
right *node[K, V]
}

func (n *node[K, V]) add(key K, value V) *node[K, V] {
func (n *node[K, V]) add(key K, value V, less g.Lesser[K]) *node[K, V] {
if n == nil {
return &node[K, V]{
key: key,
Expand All @@ -80,30 +83,30 @@ func (n *node[K, V]) add(key K, value V) *node[K, V] {
}
}

if g.Compare(key, n.key) < 0 {
n.left = n.left.add(key, value)
} else if g.Compare(key, n.key) > 0 {
n.right = n.right.add(key, value)
if g.Compare(key, n.key, less) < 0 {
n.left = n.left.add(key, value, less)
} else if g.Compare(key, n.key, less) > 0 {
n.right = n.right.add(key, value, less)
} else {
n.value = value
}
return n.rebalanceTree()
}

func (n *node[K, V]) remove(key K) *node[K, V] {
func (n *node[K, V]) remove(key K, less g.Lesser[K]) *node[K, V] {
if n == nil {
return nil
}
if g.Compare(key, n.key) < 0 {
n.left = n.left.remove(key)
} else if g.Compare(key, n.key) > 0 {
n.right = n.right.remove(key)
if g.Compare(key, n.key, less) < 0 {
n.left = n.left.remove(key, less)
} else if g.Compare(key, n.key, less) > 0 {
n.right = n.right.remove(key, less)
} else {
if n.left != nil && n.right != nil {
rightMinNode := n.right.findSmallest()
n.key = rightMinNode.key
n.value = rightMinNode.value
n.right = n.right.remove(rightMinNode.key)
n.right = n.right.remove(rightMinNode.key, less)
} else if n.left != nil {
n = n.left
} else if n.right != nil {
Expand All @@ -117,14 +120,14 @@ func (n *node[K, V]) remove(key K) *node[K, V] {
return n.rebalanceTree()
}

func (n *node[K, V]) search(key K) *node[K, V] {
func (n *node[K, V]) search(key K, less g.Lesser[K]) *node[K, V] {
if n == nil {
return nil
}
if g.Compare(key, n.key) < 0 {
return n.left.search(key)
} else if g.Compare(key, n.key) > 0 {
return n.right.search(key)
if g.Compare(key, n.key, less) < 0 {
return n.left.search(key, less)
} else if g.Compare(key, n.key, less) > 0 {
return n.right.search(key, less)
} else {
return n
}
Expand Down
14 changes: 7 additions & 7 deletions avl/avl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/zyedidia/generic/avl"
)

func checkeq[K g.Lesser[K], V comparable](cm *avl.Tree[K, V], get func(k K) (V, bool), t *testing.T) {
func checkeq[K any, V comparable](cm *avl.Tree[K, V], get func(k K) (V, bool), t *testing.T) {
cm.Iter().For(func(kv avl.KV[K, V]) {
if ov, ok := get(kv.Key); !ok {
t.Fatalf("key %v should exist", kv.Key)
Expand All @@ -21,7 +21,7 @@ func checkeq[K g.Lesser[K], V comparable](cm *avl.Tree[K, V], get func(k K) (V,

func TestCrossCheck(t *testing.T) {
stdm := make(map[int]int)
tree := avl.New[g.Int, int]()
tree := avl.New[int, int](g.Less[int])

const nops = 1000

Expand All @@ -33,34 +33,34 @@ func TestCrossCheck(t *testing.T) {
switch op {
case 0:
stdm[key] = val
tree.Put(g.Int(key), val)
tree.Put(key, val)
case 1:
var del int
for k := range stdm {
del = k
break
}
delete(stdm, del)
tree.Remove(g.Int(del))
tree.Remove(del)
}

checkeq(tree, func(k g.Int) (int, bool) {
checkeq(tree, func(k int) (int, bool) {
v, ok := stdm[int(k)]
return v, ok
}, t)
}
}

func Example() {
tree := avl.New[g.Int, g.String]()
tree := avl.New[int, string](g.Less[int])

tree.Put(42, "foo")
tree.Put(-10, "bar")
tree.Put(0, "baz")
tree.Put(10, "quux")
tree.Remove(10)

tree.Iter().For(func(kv avl.KV[g.Int, g.String]) {
tree.Iter().For(func(kv avl.KV[int, string]) {
fmt.Println(kv.Key, kv.Val)
})

Expand Down
23 changes: 13 additions & 10 deletions btree/btree.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

const maxChildren = 64 // must be even and > 2

type KV[K g.Lesser[K], V any] struct {
type KV[K any, V any] struct {
Key K
Val V
}
Expand All @@ -21,28 +21,31 @@ type KV[K g.Lesser[K], V any] struct {
// https://algs4.cs.princeton.edu/62btree/BTree.java.html.

// Tree implements a B-tree.
type Tree[K g.Lesser[K], V any] struct {
type Tree[K any, V any] struct {
root *node[K, V]
height int
n int

less g.Lesser[K]
}

type node[K g.Lesser[K], V any] struct {
type node[K any, V any] struct {
m int
children [maxChildren]entry[K, V]
}

type entry[K g.Lesser[K], V any] struct {
type entry[K any, V any] struct {
key K
val V
valid bool
next *node[K, V]
}

// New returns an empty B-tree.
func New[K g.Lesser[K], V any]() *Tree[K, V] {
func New[K any, V any](less g.Lesser[K]) *Tree[K, V] {
return &Tree[K, V]{
root: &node[K, V]{},
less: less,
}
}

Expand All @@ -62,14 +65,14 @@ func (t *Tree[K, V]) search(x *node[K, V], key K, height int) (V, bool) {
if height == 0 {
// leaf node
for j := 0; j < x.m; j++ {
if g.Compare(key, children[j].key) == 0 {
if g.Compare(key, children[j].key, t.less) == 0 {
return children[j].val, children[j].valid
}
}
} else {
// internal node
for j := 0; j < x.m; j++ {
if x.m == j+1 || g.Compare(key, children[j+1].key) < 0 {
if x.m == j+1 || g.Compare(key, children[j+1].key, t.less) < 0 {
return t.search(children[j].next, key, height-1)
}
}
Expand Down Expand Up @@ -124,18 +127,18 @@ func (t *Tree[K, V]) insert(h *node[K, V], key K, val V, height int, valid bool)
if height == 0 {
// leaf node
for j = 0; j < h.m; j++ {
if g.Compare(key, h.children[j].key) == 0 {
if g.Compare(key, h.children[j].key, t.less) == 0 {
h.children[j].val = val
h.children[j].valid = valid
return nil
} else if g.Compare(key, h.children[j].key) < 0 {
} else if g.Compare(key, h.children[j].key, t.less) < 0 {
break
}
}
} else {
// internal node
for j = 0; j < h.m; j++ {
if (j+1 == h.m) || g.Compare(key, h.children[j+1].key) < 0 {
if (j+1 == h.m) || g.Compare(key, h.children[j+1].key, t.less) < 0 {
u := t.insert(h.children[j].next, key, val, height-1, valid)
if u == nil {
return nil
Expand Down
14 changes: 7 additions & 7 deletions btree/btree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/zyedidia/generic/btree"
)

func checkeq[K g.Lesser[K], V comparable](cm *btree.Tree[K, V], get func(k K) (V, bool), t *testing.T) {
func checkeq[K any, V comparable](cm *btree.Tree[K, V], get func(k K) (V, bool), t *testing.T) {
cm.Iter().For(func(kv btree.KV[K, V]) {
if ov, ok := get(kv.Key); !ok {
t.Fatalf("key %v should exist", kv.Key)
Expand All @@ -21,7 +21,7 @@ func checkeq[K g.Lesser[K], V comparable](cm *btree.Tree[K, V], get func(k K) (V

func TestCrossCheck(t *testing.T) {
stdm := make(map[int]int)
tree := btree.New[g.Int, int]()
tree := btree.New[int, int](g.Less[int])

const nops = 1000

Expand All @@ -33,32 +33,32 @@ func TestCrossCheck(t *testing.T) {
switch op {
case 0:
stdm[key] = val
tree.Put(g.Int(key), val)
tree.Put(key, val)
case 1:
var del int
for k := range stdm {
del = k
break
}
delete(stdm, del)
tree.Remove(g.Int(del))
tree.Remove(del)
}

checkeq(tree, func(k g.Int) (int, bool) {
checkeq(tree, func(k int) (int, bool) {
v, ok := stdm[int(k)]
return v, ok
}, t)
}
}

func Example() {
tree := btree.New[g.Int, g.String]()
tree := btree.New[int, string](g.Less[int])

tree.Put(42, "foo")
tree.Put(-10, "bar")
tree.Put(0, "baz")

tree.Iter().For(func(kv btree.KV[g.Int, g.String]) {
tree.Iter().For(func(kv btree.KV[int, string]) {
fmt.Println(kv.Key, kv.Val)
})

Expand Down
5 changes: 2 additions & 3 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@ package cache_test
import (
"fmt"

g "github.com/zyedidia/generic"
"github.com/zyedidia/generic/cache"
)

func Example() {
c := cache.New[g.Int, g.Int](2)
c := cache.New[int, int](2)

c.Put(42, 42)
c.Put(10, 10)
c.Get(42)
c.Put(0, 0) // evicts 10

c.Iter().For(func(kv cache.KV[g.Int, g.Int]) {
c.Iter().For(func(kv cache.KV[int, int]) {
fmt.Println(kv.Key)
})
// Output:
Expand Down
7 changes: 3 additions & 4 deletions hashset/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@
package hashset

import (
g "github.com/zyedidia/generic"
"github.com/zyedidia/generic/hashmap"
"github.com/zyedidia/generic/iter"
)

// Set implements a hashset, using the hashmap as the underlying storage.
type Set[K g.Hashable[K]] struct {
type Set[K any] struct {
m *hashmap.Map[K, struct{}]
}

// New returns an empty hashset.
func New[K g.Hashable[K]](capacity uint64) *Set[K] {
func New[K any](capacity uint64, ops hashmap.Ops[K]) *Set[K] {
return &Set[K]{
m: hashmap.NewMap[K, struct{}](capacity),
m: hashmap.NewMap[K, struct{}](capacity, ops),
}
}

Expand Down
Loading

0 comments on commit 45189a2

Please sign in to comment.