Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eth/downloader: retrieve pivot header from local chain if necessary #24610

Merged
merged 5 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion eth/downloader/beaconsync.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,11 @@ func (d *Downloader) fetchBeaconHeaders(from uint64) error {
hashes = make([]common.Hash, 0, maxHeadersProcess)
)
for i := 0; i < maxHeadersProcess && from <= head.Number.Uint64(); i++ {
headers = append(headers, d.skeleton.Header(from))
header := d.skeleton.Header(from)
if header == nil {
header = d.lightchain.GetHeaderByNumber(from)
}
headers = append(headers, header)
hashes = append(hashes, headers[i].Hash())
from++
}
Expand Down
19 changes: 18 additions & 1 deletion eth/downloader/downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ var (
errCanceled = errors.New("syncing canceled (requested)")
errTooOld = errors.New("peer's protocol version too old")
errNoAncestorFound = errors.New("no common ancestor found")
errNoPivotHeader = errors.New("pivot header is not found")
ErrMergeTransition = errors.New("legacy sync reached the merge")
)

Expand Down Expand Up @@ -158,6 +159,9 @@ type LightChain interface {
// GetHeaderByHash retrieves a header from the local chain.
GetHeaderByHash(common.Hash) *types.Header

// GetHeaderByNumber retrieves a block header from the local chain by number.
GetHeaderByNumber(number uint64) *types.Header

// CurrentHeader retrieves the head header from the local chain.
CurrentHeader() *types.Header

Expand Down Expand Up @@ -477,7 +481,20 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td, ttd *
return err
}
if latest.Number.Uint64() > uint64(fsMinFullBlocks) {
pivot = d.skeleton.Header(latest.Number.Uint64() - uint64(fsMinFullBlocks))
number := latest.Number.Uint64() - uint64(fsMinFullBlocks)

// Retrieve the pivot header from the skeleton chain segment but
// fallback to local chain if it's not found in skeleton space.
if pivot = d.skeleton.Header(number); pivot == nil {
pivot = d.lightchain.GetHeaderByNumber(number)
}
// Print an error log and return directly in case the pivot header
// is still not found. It means the skeleton chain is not linked
// correctly with local chain.
if pivot == nil {
log.Error("Pivot header is not found", "number", number)
return errNoPivotHeader
}
}
}
// If no pivot block was returned, the head is below the min full block
Expand Down
55 changes: 54 additions & 1 deletion eth/downloader/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ type downloadTester struct {

// newTester creates a new downloader test mocker.
func newTester() *downloadTester {
return newTesterWithNotification(nil)
}

// newTester creates a new downloader test mocker.
func newTesterWithNotification(success func()) *downloadTester {
freezer, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
Expand All @@ -75,7 +80,7 @@ func newTester() *downloadTester {
chain: chain,
peers: make(map[string]*downloadTesterPeer),
}
tester.downloader = New(0, db, new(event.TypeMux), tester.chain, nil, tester.dropPeer, nil)
tester.downloader = New(0, db, new(event.TypeMux), tester.chain, nil, tester.dropPeer, success)
return tester
}

Expand Down Expand Up @@ -1368,3 +1373,51 @@ func testCheckpointEnforcement(t *testing.T, protocol uint, mode SyncMode) {
assertOwnChain(t, tester, len(chain.blocks))
}
}

// Tests that peers below a pre-configured checkpoint block are prevented from
// being fast-synced from, avoiding potential cheap eclipse attacks.
func TestBeaconSync66Full(t *testing.T) { testBeaconSync(t, eth.ETH66, FullSync) }
func TestBeaconSync66Snap(t *testing.T) { testBeaconSync(t, eth.ETH66, SnapSync) }

func testBeaconSync(t *testing.T, protocol uint, mode SyncMode) {
//log.Root().SetHandler(log.LvlFilterHandler(log.LvlInfo, log.StreamHandler(os.Stderr, log.TerminalFormat(true))))

var cases = []struct {
name string // The name of testing scenario
local int // The length of local chain(canonical chain assumed), 0 means genesis is the head
}{
{name: "Beacon sync since genesis", local: 0},
{name: "Beacon sync with short local chain", local: 1},
{name: "Beacon sync with long local chain", local: blockCacheMaxItems - 15 - fsMinFullBlocks/2},
{name: "Beacon sync with full local chain", local: blockCacheMaxItems - 15 - 1},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
success := make(chan struct{})
tester := newTesterWithNotification(func() {
close(success)
})
defer tester.terminate()

chain := testChainBase.shorten(blockCacheMaxItems - 15)
tester.newPeer("peer", protocol, chain.blocks[1:])

// Build the local chain segment if it's required
if c.local > 0 {
tester.chain.InsertChain(chain.blocks[1 : c.local+1])
}
if err := tester.downloader.BeaconSync(mode, chain.blocks[len(chain.blocks)-1].Header()); err != nil {
t.Fatalf("Failed to beacon sync chain %v %v", c.name, err)
}
select {
case <-success:
// Ok, downloader fully cancelled after sync cycle
if bs := int(tester.chain.CurrentBlock().NumberU64()) + 1; bs != len(chain.blocks) {
t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, len(chain.blocks))
}
case <-time.NewTimer(time.Second * 3).C:
t.Fatalf("Failed to sync chain in three seconds")
}
})
}
}