115 changes: 74 additions & 41 deletions p2p/discover/lookup.go
@@ -27,6 +27,7 @@ import (
// lookup performs a network search for nodes close to the given target. It approaches the
// target by querying nodes that are closer to it on each iteration. The given target does
// not need to be an actual node identifier.
// lookup on an empty table will return immediately with no nodes.
type lookup struct {
tab *Table
queryfunc queryFunc
@@ -49,11 +50,15 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l
result: nodesByDistance{target: target},
replyCh: make(chan []*enode.Node, alpha),
cancelCh: ctx.Done(),
queries: -1,
}
// Don't query further if we hit ourself.
// Unlikely to happen often in practice.
it.asked[tab.self().ID()] = true
it.seen[tab.self().ID()] = true

// Initialize the lookup with nodes from table.
closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
it.addNodes(closest.entries)
return it
}

@@ -64,22 +69,19 @@ func (it *lookup) run() []*enode.Node {
return it.result.entries
}

func (it *lookup) empty() bool {
return len(it.replyBuffer) == 0
}

// advance advances the lookup until any new nodes have been found.
// It returns false when the lookup has ended.
func (it *lookup) advance() bool {
for it.startQueries() {
select {
case nodes := <-it.replyCh:
it.replyBuffer = it.replyBuffer[:0]
for _, n := range nodes {
if n != nil && !it.seen[n.ID()] {
it.seen[n.ID()] = true
it.result.push(n, bucketSize)
it.replyBuffer = append(it.replyBuffer, n)
}
}
it.queries--
if len(it.replyBuffer) > 0 {
it.addNodes(nodes)
if !it.empty() {
return true
}
case <-it.cancelCh:
Expand All @@ -89,6 +91,17 @@ func (it *lookup) advance() bool {
return false
}

func (it *lookup) addNodes(nodes []*enode.Node) {
it.replyBuffer = it.replyBuffer[:0]
for _, n := range nodes {
if n != nil && !it.seen[n.ID()] {
it.seen[n.ID()] = true
it.result.push(n, bucketSize)
it.replyBuffer = append(it.replyBuffer, n)
}
}
}

func (it *lookup) shutdown() {
for it.queries > 0 {
<-it.replyCh
@@ -103,20 +116,6 @@ func (it *lookup) startQueries() bool {
return false
}

// The first query returns nodes from the local table.
if it.queries == -1 {
closest := it.tab.findnodeByID(it.result.target, bucketSize, false)
// Avoid finishing the lookup too quickly if table is empty. It'd be better to wait
// for the table to fill in this case, but there is no good mechanism for that
// yet.
if len(closest.entries) == 0 {
it.slowdown()
}
it.queries = 1
it.replyCh <- closest.entries
return true
}

// Ask the closest nodes that we haven't asked yet.
for i := 0; i < len(it.result.entries) && it.queries < alpha; i++ {
n := it.result.entries[i]
Expand All @@ -130,15 +129,6 @@ func (it *lookup) startQueries() bool {
return it.queries > 0
}

func (it *lookup) slowdown() {
sleep := time.NewTimer(1 * time.Second)
defer sleep.Stop()
select {
case <-sleep.C:
case <-it.tab.closeReq:
}
}

func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {
r, err := it.queryfunc(n)
if !errors.Is(err, errClosed) { // avoid recording failures on shutdown.
@@ -153,12 +143,16 @@ func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) {

// lookupIterator performs lookup operations and iterates over all seen nodes.
// When a lookup finishes, a new one is created through nextLookup.
// LookupIterator waits for table initialization and triggers a table refresh
// when necessary.
type lookupIterator struct {
buffer []*enode.Node
nextLookup lookupFunc
ctx context.Context
cancel func()
lookup *lookup
buffer []*enode.Node
nextLookup lookupFunc
ctx context.Context
cancel func()
lookup *lookup
tabRefreshing <-chan struct{}
}

type lookupFunc func(ctx context.Context) *lookup
@@ -182,6 +176,7 @@ func (it *lookupIterator) Next() bool {
if len(it.buffer) > 0 {
it.buffer = it.buffer[1:]
}

// Advance the lookup to refill the buffer.
for len(it.buffer) == 0 {
if it.ctx.Err() != nil {
Expand All @@ -191,17 +186,55 @@ func (it *lookupIterator) Next() bool {
}
if it.lookup == nil {
it.lookup = it.nextLookup(it.ctx)
if it.lookup.empty() {
// If the lookup is empty right after creation, it means the local table
// is in a degraded state, and we need to wait for it to fill again.
it.lookupFailed(it.lookup.tab, 1*time.Minute)
it.lookup = nil
continue
}
// Yield the initial nodes from the iterator before advancing the lookup.
it.buffer = it.lookup.replyBuffer
continue
}
if !it.lookup.advance() {

newNodes := it.lookup.advance()
it.buffer = it.lookup.replyBuffer
if !newNodes {
it.lookup = nil
continue
}
it.buffer = it.lookup.replyBuffer
}
return true
}

// lookupFailed handles failed lookup attempts. This can be called when the table has
// exited, or when it runs out of nodes.
func (it *lookupIterator) lookupFailed(tab *Table, timeout time.Duration) {
tout, cancel := context.WithTimeout(it.ctx, timeout)
defer cancel()

// Wait for Table initialization to complete, in case it is still in progress.
select {
case <-tab.initDone:
case <-tout.Done():
return
}

// Wait for ongoing refresh operation, or trigger one.
if it.tabRefreshing == nil {
it.tabRefreshing = tab.refresh()
}
select {
case <-it.tabRefreshing:
it.tabRefreshing = nil
case <-tout.Done():
return
}

// Wait for the table to fill.
tab.waitForNodes(tout, 1)
}

// Close ends the iterator.
func (it *lookupIterator) Close() {
it.cancel()
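
For orientation, here is a minimal consumer-side sketch (not part of the diff) of how the reworked iterator behaves. It relies only on the enode.Iterator methods exercised by the v4 test below; the collectNodes helper and its parameters are illustrative assumptions.

package example

import (
	"context"

	"github.com/ethereum/go-ethereum/p2p/enode"
)

// collectNodes reads up to max nodes from a discovery iterator such as the one
// returned by RandomNodes() in the test below. With this change, a lookup on an
// empty table ends immediately, and Next recovers by waiting in lookupFailed
// for a table refresh instead of spinning on empty lookups.
func collectNodes(ctx context.Context, src enode.Iterator, max int) []*enode.Node {
	defer src.Close()

	var out []*enode.Node
	for len(out) < max && src.Next() {
		if ctx.Err() != nil {
			break
		}
		out = append(out, src.Node())
	}
	return out
}
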
39 changes: 39 additions & 0 deletions p2p/discover/table.go
@@ -32,6 +32,7 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/mclock"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p/enode"
@@ -84,6 +85,7 @@ type Table struct {
closeReq chan struct{}
closed chan struct{}

nodeFeed event.FeedOf[*enode.Node]
nodeAddedHook func(*bucket, *tableNode)
nodeRemovedHook func(*bucket, *tableNode)
}
@@ -567,6 +569,8 @@ func (tab *Table) nodeAdded(b *bucket, n *tableNode) {
}
n.addedToBucket = time.Now()
tab.revalidation.nodeAdded(tab, n)

tab.nodeFeed.Send(n.Node)
if tab.nodeAddedHook != nil {
tab.nodeAddedHook(b, n)
}
@@ -702,3 +706,38 @@ func (tab *Table) deleteNode(n *enode.Node) {
b := tab.bucket(n.ID())
tab.deleteInBucket(b, n.ID())
}

// waitForNodes blocks until the table contains at least n nodes.
func (tab *Table) waitForNodes(ctx context.Context, n int) error {
getlength := func() (count int) {
for _, b := range &tab.buckets {
count += len(b.entries)
}
return count
}

var ch chan *enode.Node
for {
tab.mutex.Lock()
if getlength() >= n {
tab.mutex.Unlock()
return nil
}
if ch == nil {
// Init subscription.
ch = make(chan *enode.Node)
sub := tab.nodeFeed.Subscribe(ch)
defer sub.Unsubscribe()
}
tab.mutex.Unlock()

// Wait for a node add event.
select {
case <-ch:
case <-ctx.Done():
return ctx.Err()
case <-tab.closeReq:
return errClosed
}
}
}
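
The waitForNodes helper above avoids polling: it subscribes to the new nodeFeed and re-checks the bucket count after every node-added event. Below is a standalone sketch (not part of the diff) of the same subscribe-then-recheck pattern using event.FeedOf with an int payload; the counter type and waitForCount name are illustrative assumptions.

package example

import (
	"context"
	"sync"

	"github.com/ethereum/go-ethereum/event"
)

// counter is a toy stand-in for the table: a guarded count plus a feed that
// fires on every increment, mirroring nodeFeed.Send in nodeAdded above.
type counter struct {
	mu   sync.Mutex
	n    int
	feed event.FeedOf[int]
}

func (c *counter) add() {
	c.mu.Lock()
	c.n++
	v := c.n
	c.mu.Unlock()
	c.feed.Send(v) // wakes any waiter, like nodeFeed.Send in nodeAdded
}

// waitForCount blocks until the count reaches target, in the same way
// Table.waitForNodes waits for at least n bucket entries.
func (c *counter) waitForCount(ctx context.Context, target int) error {
	ch := make(chan int)
	sub := c.feed.Subscribe(ch)
	defer sub.Unsubscribe()

	for {
		c.mu.Lock()
		done := c.n >= target
		c.mu.Unlock()
		if done {
			return nil
		}
		// Wait for the next add event or cancellation.
		select {
		case <-ch:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}
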
23 changes: 16 additions & 7 deletions p2p/discover/v4_udp_test.go
@@ -24,10 +24,12 @@ import (
"errors"
"fmt"
"io"
"maps"
"math/rand"
"net"
"net/netip"
"reflect"
"slices"
"sync"
"testing"
"time"
@@ -509,18 +511,26 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
// they have all found each other.
status := make(chan error, len(nodes))
for i := range nodes {
node := nodes[i]
self := nodes[i]
go func() {
found := make(map[enode.ID]bool, len(nodes))
it := node.RandomNodes()
missing := make(map[enode.ID]bool, len(nodes))
for _, n := range nodes {
if n.Self().ID() == self.Self().ID() {
continue // skip self
}
missing[n.Self().ID()] = true
}

it := self.RandomNodes()
for it.Next() {
found[it.Node().ID()] = true
if len(found) == len(nodes) {
delete(missing, it.Node().ID())
if len(missing) == 0 {
status <- nil
return
}
}
status <- fmt.Errorf("node %s didn't find all nodes", node.Self().ID().TerminalString())
missingIDs := slices.Collect(maps.Keys(missing))
status <- fmt.Errorf("node %s didn't find all nodes, missing %v", self.Self().ID().TerminalString(), missingIDs)
}()
}

@@ -537,7 +547,6 @@ func TestUDPv4_smallNetConvergence(t *testing.T) {
received++
if err != nil {
t.Error("ERROR:", err)
return
}
}
}