Skip to content

Commit 6624270

Browse files
joshua-kimStephenButtolphdhrubabasudboehm-avalabsabi87
authored
Add Heap Set (#2136)
Signed-off-by: Joshua Kim <20001595+joshua-kim@users.noreply.github.com> Signed-off-by: Stephen Buttolph <stephen@avalabs.org> Co-authored-by: Stephen Buttolph <stephen@avalabs.org> Co-authored-by: Dhruba Basu <7675102+dhrubabasu@users.noreply.github.com> Co-authored-by: dboehm-avalabs <david.boehm@avalabs.org> Co-authored-by: David Boehm <91908103+dboehm-avalabs@users.noreply.github.com> Co-authored-by: Alberto Benegiamo <alberto.benegiamo@gmail.com>
1 parent 804f45b commit 6624270

File tree

4 files changed

+179
-199
lines changed

4 files changed

+179
-199
lines changed

utils/heap/set.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
2+
// See the file LICENSE for licensing terms.
3+
4+
package heap
5+
6+
// NewSet returns a heap without duplicates ordered by its values
7+
func NewSet[T comparable](less func(a, b T) bool) Set[T] {
8+
return Set[T]{
9+
set: NewMap[T, T](less),
10+
}
11+
}
12+
13+
type Set[T comparable] struct {
14+
set Map[T, T]
15+
}
16+
17+
// Push returns if the entry was added
18+
func (s Set[T]) Push(t T) bool {
19+
_, hadValue := s.set.Push(t, t)
20+
return !hadValue
21+
}
22+
23+
func (s Set[T]) Pop() (T, bool) {
24+
pop, _, ok := s.set.Pop()
25+
return pop, ok
26+
}
27+
28+
func (s Set[T]) Peek() (T, bool) {
29+
peek, _, ok := s.set.Peek()
30+
return peek, ok
31+
}
32+
33+
func (s Set[T]) Len() int {
34+
return s.set.Len()
35+
}
36+
37+
func (s Set[T]) Remove(t T) bool {
38+
_, existed := s.set.Remove(t)
39+
return existed
40+
}
41+
42+
func (s Set[T]) Fix(t T) {
43+
s.set.Fix(t)
44+
}
45+
46+
func (s Set[T]) Contains(t T) bool {
47+
return s.set.Contains(t)
48+
}

utils/heap/set_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (C) 2019-2023, Ava Labs, Inc. All rights reserved.
2+
// See the file LICENSE for licensing terms.
3+
4+
package heap
5+
6+
import (
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestSet(t *testing.T) {
13+
tests := []struct {
14+
name string
15+
setup func(h Set[int])
16+
expected []int
17+
}{
18+
{
19+
name: "only push",
20+
setup: func(h Set[int]) {
21+
h.Push(1)
22+
h.Push(2)
23+
h.Push(3)
24+
},
25+
expected: []int{1, 2, 3},
26+
},
27+
{
28+
name: "out of order pushes",
29+
setup: func(h Set[int]) {
30+
h.Push(1)
31+
h.Push(5)
32+
h.Push(2)
33+
h.Push(4)
34+
h.Push(3)
35+
},
36+
expected: []int{1, 2, 3, 4, 5},
37+
},
38+
{
39+
name: "push and pop",
40+
setup: func(h Set[int]) {
41+
h.Push(1)
42+
h.Push(5)
43+
h.Push(2)
44+
h.Push(4)
45+
h.Push(3)
46+
h.Pop()
47+
h.Pop()
48+
h.Pop()
49+
},
50+
expected: []int{4, 5},
51+
},
52+
}
53+
54+
for _, tt := range tests {
55+
t.Run(tt.name, func(t *testing.T) {
56+
require := require.New(t)
57+
58+
h := NewSet[int](func(a, b int) bool {
59+
return a < b
60+
})
61+
62+
tt.setup(h)
63+
64+
require.Equal(len(tt.expected), h.Len())
65+
for _, expected := range tt.expected {
66+
got, ok := h.Pop()
67+
require.True(ok)
68+
require.Equal(expected, got)
69+
}
70+
})
71+
}
72+
}

x/sync/workheap.go

Lines changed: 34 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,38 @@ package sync
55

66
import (
77
"bytes"
8-
"container/heap"
98

9+
"github.com/ava-labs/avalanchego/utils/heap"
1010
"github.com/ava-labs/avalanchego/utils/math"
1111
"github.com/ava-labs/avalanchego/utils/maybe"
1212

1313
"github.com/google/btree"
1414
)
1515

16-
var _ heap.Interface = (*innerHeap)(nil)
17-
18-
type heapItem struct {
19-
workItem *workItem
20-
heapIndex int
21-
}
22-
23-
type innerHeap []*heapItem
24-
2516
// A priority queue of syncWorkItems.
2617
// Note that work item ranges never overlap.
2718
// Supports range merging and priority updating.
2819
// Not safe for concurrent use.
2920
type workHeap struct {
3021
// Max heap of items by priority.
3122
// i.e. heap.Pop returns highest priority item.
32-
innerHeap innerHeap
23+
innerHeap heap.Set[*workItem]
3324
// The heap items sorted by range start.
3425
// A Nothing start is considered to be the smallest.
35-
sortedItems *btree.BTreeG[*heapItem]
26+
sortedItems *btree.BTreeG[*workItem]
3627
closed bool
3728
}
3829

3930
func newWorkHeap() *workHeap {
4031
return &workHeap{
32+
innerHeap: heap.NewSet[*workItem](func(a, b *workItem) bool {
33+
return a.priority > b.priority
34+
}),
4135
sortedItems: btree.NewG(
4236
2,
43-
func(a, b *heapItem) bool {
44-
aNothing := a.workItem.start.IsNothing()
45-
bNothing := b.workItem.start.IsNothing()
37+
func(a, b *workItem) bool {
38+
aNothing := a.start.IsNothing()
39+
bNothing := b.start.IsNothing()
4640
if aNothing {
4741
// [a] is Nothing, so if [b] is Nothing, they're equal.
4842
// Otherwise, [b] is greater.
@@ -53,7 +47,7 @@ func newWorkHeap() *workHeap {
5347
return false
5448
}
5549
// [a] and [b] both contain values. Compare the values.
56-
return bytes.Compare(a.workItem.start.Value(), b.workItem.start.Value()) < 0
50+
return bytes.Compare(a.start.Value(), b.start.Value()) < 0
5751
},
5852
),
5953
}
@@ -70,10 +64,8 @@ func (wh *workHeap) Insert(item *workItem) {
7064
return
7165
}
7266

73-
wrappedItem := &heapItem{workItem: item}
74-
75-
heap.Push(&wh.innerHeap, wrappedItem)
76-
wh.sortedItems.ReplaceOrInsert(wrappedItem)
67+
wh.innerHeap.Push(item)
68+
wh.sortedItems.ReplaceOrInsert(item)
7769
}
7870

7971
// Pops and returns a work item from the heap.
@@ -82,9 +74,9 @@ func (wh *workHeap) GetWork() *workItem {
8274
if wh.closed || wh.Len() == 0 {
8375
return nil
8476
}
85-
item := heap.Pop(&wh.innerHeap).(*heapItem)
77+
item, _ := wh.innerHeap.Pop()
8678
wh.sortedItems.Delete(item)
87-
return item.workItem
79+
return item
8880
}
8981

9082
// Insert the item into the heap, merging it with existing items
@@ -99,25 +91,23 @@ func (wh *workHeap) MergeInsert(item *workItem) {
9991
return
10092
}
10193

102-
var mergedBefore, mergedAfter *heapItem
103-
searchItem := &heapItem{
104-
workItem: &workItem{
105-
start: item.start,
106-
},
94+
var mergedBefore, mergedAfter *workItem
95+
searchItem := &workItem{
96+
start: item.start,
10797
}
10898

10999
// Find the item with the greatest start range which is less than [item.start].
110100
// Note that the iterator function will run at most once, since it always returns false.
111101
wh.sortedItems.DescendLessOrEqual(
112102
searchItem,
113-
func(beforeItem *heapItem) bool {
114-
if item.localRootID == beforeItem.workItem.localRootID &&
115-
maybe.Equal(item.start, beforeItem.workItem.end, bytes.Equal) {
103+
func(beforeItem *workItem) bool {
104+
if item.localRootID == beforeItem.localRootID &&
105+
maybe.Equal(item.start, beforeItem.end, bytes.Equal) {
116106
// [beforeItem.start, beforeItem.end] and [item.start, item.end] are
117107
// merged into [beforeItem.start, item.end]
118-
beforeItem.workItem.end = item.end
119-
beforeItem.workItem.priority = math.Max(item.priority, beforeItem.workItem.priority)
120-
heap.Fix(&wh.innerHeap, beforeItem.heapIndex)
108+
beforeItem.end = item.end
109+
beforeItem.priority = math.Max(item.priority, beforeItem.priority)
110+
wh.innerHeap.Fix(beforeItem)
121111
mergedBefore = beforeItem
122112
}
123113
return false
@@ -127,14 +117,14 @@ func (wh *workHeap) MergeInsert(item *workItem) {
127117
// Note that the iterator function will run at most once, since it always returns false.
128118
wh.sortedItems.AscendGreaterOrEqual(
129119
searchItem,
130-
func(afterItem *heapItem) bool {
131-
if item.localRootID == afterItem.workItem.localRootID &&
132-
maybe.Equal(item.end, afterItem.workItem.start, bytes.Equal) {
120+
func(afterItem *workItem) bool {
121+
if item.localRootID == afterItem.localRootID &&
122+
maybe.Equal(item.end, afterItem.start, bytes.Equal) {
133123
// [item.start, item.end] and [afterItem.start, afterItem.end] are merged into
134124
// [item.start, afterItem.end].
135-
afterItem.workItem.start = item.start
136-
afterItem.workItem.priority = math.Max(item.priority, afterItem.workItem.priority)
137-
heap.Fix(&wh.innerHeap, afterItem.heapIndex)
125+
afterItem.start = item.start
126+
afterItem.priority = math.Max(item.priority, afterItem.priority)
127+
wh.innerHeap.Fix(afterItem)
138128
mergedAfter = afterItem
139129
}
140130
return false
@@ -144,12 +134,12 @@ func (wh *workHeap) MergeInsert(item *workItem) {
144134
// we can combine the before item with the after item
145135
if mergedBefore != nil && mergedAfter != nil {
146136
// combine the two ranges
147-
mergedBefore.workItem.end = mergedAfter.workItem.end
137+
mergedBefore.end = mergedAfter.end
148138
// remove the second range since it is now covered by the first
149139
wh.remove(mergedAfter)
150140
// update the priority
151-
mergedBefore.workItem.priority = math.Max(mergedBefore.workItem.priority, mergedAfter.workItem.priority)
152-
heap.Fix(&wh.innerHeap, mergedBefore.heapIndex)
141+
mergedBefore.priority = math.Max(mergedBefore.priority, mergedAfter.priority)
142+
wh.innerHeap.Fix(mergedBefore)
153143
}
154144

155145
// nothing was merged, so add new item to the heap
@@ -160,43 +150,11 @@ func (wh *workHeap) MergeInsert(item *workItem) {
160150
}
161151

162152
// Deletes [item] from the heap.
163-
func (wh *workHeap) remove(item *heapItem) {
164-
heap.Remove(&wh.innerHeap, item.heapIndex)
165-
153+
func (wh *workHeap) remove(item *workItem) {
154+
wh.innerHeap.Remove(item)
166155
wh.sortedItems.Delete(item)
167156
}
168157

169158
func (wh *workHeap) Len() int {
170159
return wh.innerHeap.Len()
171160
}
172-
173-
// below this line are the implementations required for heap.Interface
174-
175-
func (h innerHeap) Len() int {
176-
return len(h)
177-
}
178-
179-
func (h innerHeap) Less(i int, j int) bool {
180-
return h[i].workItem.priority > h[j].workItem.priority
181-
}
182-
183-
func (h innerHeap) Swap(i int, j int) {
184-
h[i], h[j] = h[j], h[i]
185-
h[i].heapIndex = i
186-
h[j].heapIndex = j
187-
}
188-
189-
func (h *innerHeap) Pop() interface{} {
190-
old := *h
191-
n := len(old)
192-
item := old[n-1]
193-
old[n-1] = nil
194-
*h = old[0 : n-1]
195-
return item
196-
}
197-
198-
func (h *innerHeap) Push(x interface{}) {
199-
item := x.(*heapItem)
200-
item.heapIndex = len(*h)
201-
*h = append(*h, item)
202-
}

0 commit comments

Comments
 (0)