Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conditionally allocate WaitGroup memory #2901

Merged
merged 22 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 65 additions & 6 deletions x/merkledb/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,66 @@ func (v *view) hashChangedNodes(ctx context.Context) {
// Calculates the ID of all descendants of [n] which need to be recalculated,
// and then calculates the ID of [n] itself.
func (v *view) hashChangedNode(n *node) ids.ID {
// We use [wg] to wait until all descendants of [n] have been updated.
var wg sync.WaitGroup
// If there are no children, we can avoid allocating [keyBuffer].
if len(n.children) == 0 {
return n.calculateID(v.db.metrics)
}

// Calculate the size of the largest child key of this node. This allows
// only allocating a single slice for all of the keys.
var maxChildBitLength int
for _, childEntry := range n.children {
maxChildBitLength = max(maxChildBitLength, childEntry.compressedKey.length)
}

var (
maxBytesNeeded = bytesNeeded(n.key.length + v.tokenSize + maxChildBitLength)
// keyBuffer is allocated onto the heap because it is dynamically sized.
keyBuffer = make([]byte, maxBytesNeeded)
// childBuffer is allocated on the stack.
childBuffer = make([]byte, 1)
dualIndex = dualBitIndex(v.tokenSize)
bytesForKey = bytesNeeded(n.key.length)
// We track the last byte of [n.key] so that we can reset the value for
// each key. This is needed because the child buffer may get ORed at
// this byte.
lastKeyByte byte

// We use [wg] to wait until all descendants of [n] have been updated.
wg waitGroup
)

if bytesForKey > 0 {
// We only need to copy this node's key once because it does not change
// as we iterate over the children.
copy(keyBuffer, n.key.value)
lastKeyByte = keyBuffer[bytesForKey-1]
}

for childIndex, childEntry := range n.children {
childEntry := childEntry // New variable so goroutine doesn't capture loop variable.
childKey := n.key.Extend(ToToken(childIndex, v.tokenSize), childEntry.compressedKey)
childBuffer[0] = childIndex << dualIndex
childIndexAsKey := Key{
// It is safe to use byteSliceToString because [childBuffer] is not
// modified while [childIndexAsKey] is in use.
value: byteSliceToString(childBuffer),
length: v.tokenSize,
}

totalBitLength := n.key.length + v.tokenSize + childEntry.compressedKey.length
buffer := keyBuffer[:bytesNeeded(totalBitLength)]
// Make sure the last byte of the key is originally set correctly
if bytesForKey > 0 {
buffer[bytesForKey-1] = lastKeyByte
}
extendIntoBuffer(buffer, childIndexAsKey, n.key.length)
extendIntoBuffer(buffer, childEntry.compressedKey, n.key.length+v.tokenSize)
childKey := Key{
// It is safe to use byteSliceToString because [buffer] is not
// modified while [childKey] is in use.
value: byteSliceToString(buffer),
length: totalBitLength,
}

childNodeChange, ok := v.changes.nodes[childKey]
if !ok {
// This child wasn't changed.
Expand All @@ -306,11 +360,16 @@ func (v *view) hashChangedNode(n *node) ids.ID {
// Try updating the child and its descendants in a goroutine.
if ok := v.db.hashNodesSema.TryAcquire(1); ok {
wg.Add(1)
go func() {

// Passing variables explicitly through the function call rather
// than implicitly passing them through the scope of the function
// definition allows the passed variables to be allocated on the
// stack.
go func(wg *sync.WaitGroup, childEntry *child) {
childEntry.id = v.hashChangedNode(childNodeChange.after)
v.db.hashNodesSema.Release(1)
wg.Done()
}()
}(wg.wg, childEntry)
} else {
// We're at the goroutine limit; do the work in this goroutine.
childEntry.id = v.hashChangedNode(childNodeChange.after)
Expand Down
105 changes: 105 additions & 0 deletions x/merkledb/view_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package merkledb

import (
"context"
"encoding/binary"
"testing"

"github.com/stretchr/testify/require"

"github.com/ava-labs/avalanchego/database"
"github.com/ava-labs/avalanchego/database/memdb"
"github.com/ava-labs/avalanchego/utils/hashing"
)

// hashChangedNodesTests enumerates database sizes together with the root
// hash expected after hashing all changed nodes. The table is shared by
// Test_HashChangedNodes and Benchmark_HashChangedNodes below.
var hashChangedNodesTests = []struct {
	name string
	// numKeys is the number of sequential uvarint-encoded keys inserted.
	numKeys uint64
	// expectedRootHash is the string form of the root ID after hashing.
	expectedRootHash string
}{
	{
		name:             "1",
		numKeys:          1,
		expectedRootHash: "2A4DRkSWbTvSxgA1UMGp1Mpt1yzMFaeMMiDnrijVGJXPcRYiD4",
	},
	{
		name:             "10",
		numKeys:          10,
		expectedRootHash: "2PGy7QvbYwVwn5QmLgj4KBgV2BisanZE8Nue2SxK9ffybb4mAn",
	},
	{
		name:             "100",
		numKeys:          100,
		expectedRootHash: "LCeS4DWh6TpNKWH4ke9a2piSiwwLbmxGUj8XuaWx1XDGeCMAv",
	},
	{
		name:             "1000",
		numKeys:          1000,
		expectedRootHash: "2S6f84wdRHmnx51mj35DF2owzf8wio5pzNJXfEWfFYFNxUB64T",
	},
	{
		name:             "10000",
		numKeys:          10000,
		expectedRootHash: "wF6UnhaDoA9fAqiXAcx27xCYBK2aspDBEXkicmC7rs8EzLCD8",
	},
	{
		name:             "100000",
		numKeys:          100000,
		expectedRootHash: "2Dy3RWZeNDUnUvzXpruB5xdp1V7xxb14M53ywdZVACDkdM66M1",
	},
}

// makeViewForHashChangedNodes builds an in-memory database, opens a view
// containing [numKeys] sequential uvarint keys as batch ops, and calculates
// (but does not hash) the node changes so the caller can invoke
// hashChangedNodes directly. [parallelism] bounds root-generation concurrency.
func makeViewForHashChangedNodes(t require.TestingT, numKeys uint64, parallelism uint) *view {
	config := newDefaultConfig()
	config.RootGenConcurrency = parallelism

	db, err := newDatabase(
		context.Background(),
		memdb.New(),
		config,
		&mockMetrics{},
	)
	require.NoError(t, err)

	// Each key is the uvarint encoding of its index; each value is the
	// sha256 of the key, matching the original fixture exactly.
	batchOps := make([]database.BatchOp, 0, numKeys)
	for i := uint64(0); i < numKeys; i++ {
		key := binary.AppendUvarint(nil, i)
		batchOps = append(batchOps, database.BatchOp{
			Key:   key,
			Value: hashing.ComputeHash256(key),
		})
	}

	ctx := context.Background()
	trieIntf, err := db.NewView(ctx, ViewChanges{BatchOps: batchOps})
	require.NoError(t, err)

	trie := trieIntf.(*view)
	require.NoError(t, trie.calculateNodeChanges(ctx))
	return trie
}

// Test_HashChangedNodes verifies that hashing a view's changed nodes yields
// the expected root hash for each database size in hashChangedNodesTests.
func Test_HashChangedNodes(t *testing.T) {
	for _, test := range hashChangedNodesTests {
		t.Run(test.name, func(t *testing.T) {
			ctx := context.Background()
			// 16 hashing goroutines exercises the parallel path.
			trie := makeViewForHashChangedNodes(t, test.numKeys, 16)
			trie.hashChangedNodes(ctx)
			require.Equal(t, test.expectedRootHash, trie.changes.rootID.String())
		})
	}
}

// Benchmark_HashChangedNodes measures repeated hashing of the same set of
// changed nodes with parallelism of 1 (no goroutine fan-out), for each
// database size in hashChangedNodesTests.
func Benchmark_HashChangedNodes(b *testing.B) {
	for _, test := range hashChangedNodesTests {
		// View construction happens outside b.Run so it is not timed.
		trie := makeViewForHashChangedNodes(b, test.numKeys, 1)
		ctx := context.Background()
		b.Run(test.name, func(b *testing.B) {
			for n := 0; n < b.N; n++ {
				trie.hashChangedNodes(ctx)
			}
		})
	}
}
25 changes: 25 additions & 0 deletions x/merkledb/wait_group.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package merkledb

import "sync"

// waitGroup lazily wraps a sync.WaitGroup so that no heap allocation is
// performed unless Add is actually called. The zero value is ready to use.
type waitGroup struct {
	// wg is nil until the first call to Add.
	wg *sync.WaitGroup
}

// Add registers [delta] additional operations to wait for, allocating the
// underlying sync.WaitGroup on first use.
func (wg *waitGroup) Add(delta int) {
	if wg.wg == nil {
		wg.wg = new(sync.WaitGroup)
	}
	wg.wg.Add(delta)
}

// Wait blocks until all registered operations have completed. If Add was
// never called there is nothing to wait for, so Wait returns immediately.
func (wg *waitGroup) Wait() {
	if wg.wg == nil {
		return
	}
	wg.wg.Wait()
}
29 changes: 29 additions & 0 deletions x/merkledb/wait_group_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package merkledb

import "testing"

// Benchmark_WaitGroup_Wait measures Wait on a waitGroup whose Add was never
// called, which is the allocation-free fast path.
func Benchmark_WaitGroup_Wait(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var wg waitGroup
		wg.Wait()
	}
}

// Benchmark_WaitGroup_Add measures the cost of the first Add on a fresh
// waitGroup, which allocates the inner sync.WaitGroup.
func Benchmark_WaitGroup_Add(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var wg waitGroup
		wg.Add(1)
	}
}

// Benchmark_WaitGroup_AddDoneWait measures a full Add/Done/Wait cycle on a
// fresh waitGroup.
func Benchmark_WaitGroup_AddDoneWait(b *testing.B) {
	for n := 0; n < b.N; n++ {
		var wg waitGroup
		wg.Add(1)
		// waitGroup has no Done wrapper, so reach into the inner
		// sync.WaitGroup directly (non-nil after Add).
		wg.wg.Done()
		wg.Wait()
	}
}
Loading