Skip to content

Commit

Permalink
eth/downloader: remove header rollback mechanism #28147
Browse files Browse the repository at this point in the history
  • Loading branch information
joeylichang committed Oct 12, 2023
1 parent 68722d0 commit 77a80cd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 130 deletions.
91 changes: 41 additions & 50 deletions eth/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ var (
errCanceled = errors.New("syncing canceled (requested)")
errTooOld = errors.New("peer's protocol version too old")
errNoAncestorFound = errors.New("no common ancestor found")
ErrMergeTransition = errors.New("legacy sync reached the merge")
)

// peerDropFn is a callback type for dropping a peer detected as malicious.
Expand Down Expand Up @@ -1231,45 +1232,20 @@ func (d *Downloader) fetchReceipts(from uint64, beaconMode bool) error {
return err
}

// processHeaders takes batches of retrieved headers from an input channel and
// keeps processing and scheduling them into the header chain and downloader's
// queue until the stream ends or a failure occurs.
// processHeaders takes batches of retrieved headers from an input channel and
// keeps processing and scheduling them into the header chain and downloader's
// queue until the stream ends or a failure occurs.
func (d *Downloader) processHeaders(origin uint64, td, ttd *big.Int, beaconMode bool) error {
// Keep a count of uncertain headers to roll back
var (
rollback uint64 // Zero means no rollback (fine as you can't unroll the genesis)
rollbackErr error
mode = d.getMode()
mode = d.getMode()
gotHeaders = false // Wait for batches of headers to process
)
defer func() {
if rollback > 0 {
lastHeader, lastFastBlock, lastBlock := d.lightchain.CurrentHeader().Number, common.Big0, common.Big0
if mode != LightSync {
lastFastBlock = d.blockchain.CurrentSnapBlock().Number
lastBlock = d.blockchain.CurrentBlock().Number
}
if err := d.lightchain.SetHead(rollback - 1); err != nil { // -1 to target the parent of the first uncertain block
// We're already unwinding the stack, only print the error to make it more visible
log.Error("Failed to roll back chain segment", "head", rollback-1, "err", err)
}
curFastBlock, curBlock := common.Big0, common.Big0
if mode != LightSync {
curFastBlock = d.blockchain.CurrentSnapBlock().Number
curBlock = d.blockchain.CurrentBlock().Number
}
log.Warn("Rolled back chain segment",
"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
"snap", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
"block", fmt.Sprintf("%d->%d", lastBlock, curBlock), "reason", rollbackErr)
}
}()
// Wait for batches of headers to process
gotHeaders := false

for {
select {
case <-d.cancelCh:
rollbackErr = errCanceled
return errCanceled

case task := <-d.headerProcCh:
Expand Down Expand Up @@ -1318,8 +1294,6 @@ func (d *Downloader) processHeaders(origin uint64, td, ttd *big.Int, beaconMode
}
}
}
// Disable any rollback and return
rollback = 0
return nil
}
// Otherwise split the chunk of headers into batches and process them
Expand All @@ -1330,7 +1304,6 @@ func (d *Downloader) processHeaders(origin uint64, td, ttd *big.Int, beaconMode
// Terminate if something failed in between processing chunks
select {
case <-d.cancelCh:
rollbackErr = errCanceled
return errCanceled
default:
}
Expand All @@ -1344,26 +1317,46 @@ func (d *Downloader) processHeaders(origin uint64, td, ttd *big.Int, beaconMode

// In case of header only syncing, validate the chunk immediately
if mode == SnapSync || mode == LightSync {
// Although the received headers might be all valid, a legacy
// PoW/PoA sync must not accept post-merge headers. Make sure
// that any transition is rejected at this point.
var (
rejected []*types.Header
td *big.Int
)
if !beaconMode && ttd != nil {
td = d.blockchain.GetTd(chunkHeaders[0].ParentHash, chunkHeaders[0].Number.Uint64()-1)
if td == nil {
// This should never really happen, but handle gracefully for now
log.Error("Failed to retrieve parent header TD", "number", chunkHeaders[0].Number.Uint64()-1, "hash", chunkHeaders[0].ParentHash)
return fmt.Errorf("%w: parent TD missing", errInvalidChain)
}
for i, header := range chunkHeaders {
td = new(big.Int).Add(td, header.Difficulty)
if td.Cmp(ttd) >= 0 {
// Terminal total difficulty reached, allow the last header in
if new(big.Int).Sub(td, header.Difficulty).Cmp(ttd) < 0 {
chunkHeaders, rejected = chunkHeaders[:i+1], chunkHeaders[i+1:]
if len(rejected) > 0 {
// Make a nicer user log as to the first TD truly rejected
td = new(big.Int).Add(td, rejected[0].Difficulty)
}
} else {
chunkHeaders, rejected = chunkHeaders[:i], chunkHeaders[i:]
}
break
}
}
}
if len(chunkHeaders) > 0 {
if n, err := d.lightchain.InsertHeaderChain(chunkHeaders); err != nil {
rollbackErr = err

// If some headers were inserted, track them as uncertain
if mode == SnapSync && n > 0 && rollback == 0 {
rollback = chunkHeaders[0].Number.Uint64()
}
log.Warn("Invalid header encountered", "number", chunkHeaders[n].Number, "hash", chunkHashes[n], "parent", chunkHeaders[n].ParentHash, "err", err)
return fmt.Errorf("%w: %v", errInvalidChain, err)
}
// All verifications passed, track all headers within the alloted limits
if mode == SnapSync {
head := chunkHeaders[len(chunkHeaders)-1].Number.Uint64()
if head-rollback > uint64(fsHeaderSafetyNet) {
rollback = head - uint64(fsHeaderSafetyNet)
} else {
rollback = 1
}
}
}
if len(rejected) != 0 {
log.Info("Legacy sync reached merge threshold", "number", rejected[0].Number, "hash", rejected[0].Hash(), "td", td, "ttd", ttd)
return ErrMergeTransition
}
}
// Unless we're doing light chains, schedule the headers for associated content retrieval
Expand All @@ -1372,15 +1365,13 @@ func (d *Downloader) processHeaders(origin uint64, td, ttd *big.Int, beaconMode
for d.queue.PendingBodies() >= maxQueuedHeaders || d.queue.PendingReceipts() >= maxQueuedHeaders {
select {
case <-d.cancelCh:
rollbackErr = errCanceled
return errCanceled
case <-time.After(time.Second):
}
}
// Otherwise insert the headers for content retrieval
inserts := d.queue.Schedule(chunkHeaders, chunkHashes, origin)
if len(inserts) != len(chunkHeaders) {
rollbackErr = fmt.Errorf("stale headers: len inserts %v len(chunk) %v", len(inserts), len(chunkHeaders))
return fmt.Errorf("%w: stale headers", errBadPeer)
}
}
Expand Down
80 changes: 0 additions & 80 deletions eth/downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -881,86 +881,6 @@ func testShiftedHeaderAttack(t *testing.T, protocol uint, mode SyncMode) {
assertOwnChain(t, tester, len(chain.blocks))
}

// Tests that upon detecting an invalid header, the recent ones are rolled back
// for various failure scenarios. Afterwards a full sync is attempted to make
// sure no state was corrupted.
func TestInvalidHeaderRollback66Snap(t *testing.T) { testInvalidHeaderRollback(t, eth.ETH66, SnapSync) }
func TestInvalidHeaderRollback67Snap(t *testing.T) { testInvalidHeaderRollback(t, eth.ETH67, SnapSync) }

func testInvalidHeaderRollback(t *testing.T, protocol uint, mode SyncMode) {
tester := newTester(t)
defer tester.terminate()

// Create a small enough block chain to download
targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
chain := testChainBase.shorten(targetBlocks)

// Attempt to sync with an attacker that feeds junk during the fast sync phase.
// This should result in the last fsHeaderSafetyNet headers being rolled back.
missing := fsHeaderSafetyNet + MaxHeaderFetch + 1

fastAttacker := tester.newPeer("fast-attack", protocol, chain.blocks[1:])
fastAttacker.withholdHeaders[chain.blocks[missing].Hash()] = struct{}{}

if err := tester.sync("fast-attack", nil, mode); err == nil {
t.Fatalf("succeeded fast attacker synchronisation")
}
if head := tester.chain.CurrentHeader().Number.Int64(); int(head) > MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, MaxHeaderFetch)
}
// Attempt to sync with an attacker that feeds junk during the block import phase.
// This should result in both the last fsHeaderSafetyNet number of headers being
// rolled back, and also the pivot point being reverted to a non-block status.
missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1

blockAttacker := tester.newPeer("block-attack", protocol, chain.blocks[1:])
fastAttacker.withholdHeaders[chain.blocks[missing].Hash()] = struct{}{} // Make sure the fast-attacker doesn't fill in
blockAttacker.withholdHeaders[chain.blocks[missing].Hash()] = struct{}{}

if err := tester.sync("block-attack", nil, mode); err == nil {
t.Fatalf("succeeded block attacker synchronisation")
}
if head := tester.chain.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
}
if mode == SnapSync {
if head := tester.chain.CurrentBlock().Number.Uint64(); head != 0 {
t.Errorf("fast sync pivot block #%d not rolled back", head)
}
}
// Attempt to sync with an attacker that withholds promised blocks after the
// fast sync pivot point. This could be a trial to leave the node with a bad
// but already imported pivot block.
withholdAttacker := tester.newPeer("withhold-attack", protocol, chain.blocks[1:])

tester.downloader.syncInitHook = func(uint64, uint64) {
for i := missing; i < len(chain.blocks); i++ {
withholdAttacker.withholdHeaders[chain.blocks[i].Hash()] = struct{}{}
}
tester.downloader.syncInitHook = nil
}
if err := tester.sync("withhold-attack", nil, mode); err == nil {
t.Fatalf("succeeded withholding attacker synchronisation")
}
if head := tester.chain.CurrentHeader().Number.Int64(); int(head) > 2*fsHeaderSafetyNet+MaxHeaderFetch {
t.Errorf("rollback head mismatch: have %v, want at most %v", head, 2*fsHeaderSafetyNet+MaxHeaderFetch)
}
if mode == SnapSync {
if head := tester.chain.CurrentBlock().Number.Uint64(); head != 0 {
t.Errorf("fast sync pivot block #%d not rolled back", head)
}
}
// Synchronise with the valid peer and make sure sync succeeds. Since the last rollback
// should also disable fast syncing for this process, verify that we did a fresh full
// sync. Note, we can't assert anything about the receipts since we won't purge the
// database of them, hence we can't use assertOwnChain.
tester.newPeer("valid", protocol, chain.blocks[1:])
if err := tester.sync("valid", nil, mode); err != nil {
t.Fatalf("failed to synchronise blocks: %v", err)
}
assertOwnChain(t, tester, len(chain.blocks))
}

// Tests that a peer advertising a high TD doesn't get to stall the downloader
// afterwards by not sending any useful hashes.
func TestHighTDStarvationAttack66Full(t *testing.T) {
Expand Down

0 comments on commit 77a80cd

Please sign in to comment.