diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 21567a0c9c..70c41d791d 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -154,13 +154,18 @@ type downloadTesterPeer struct { chain *core.BlockChain withholdHeaders map[common.Hash]struct{} + fakeTD *big.Int } // Head constructs a function to retrieve a peer's current head hash // and total difficulty. func (dlp *downloadTesterPeer) Head() (common.Hash, *big.Int, *big.Int) { head := dlp.chain.CurrentBlock() - return head.Hash(), dlp.chain.GetTd(head.Hash(), head.Number.Uint64()), new(big.Int).Set(dlp.chain.CurrentBlock().Difficulty) + td := dlp.chain.GetTd(head.Hash(), head.Number.Uint64()) + if dlp.fakeTD != nil { + td.Set(dlp.fakeTD) + } + return head.Hash(), td, new(big.Int).Set(dlp.chain.CurrentBlock().Difficulty) } // SetHead constructs a function to retrieve a peer's current head hash @@ -992,8 +997,10 @@ func testHighTDStarvationAttack(t *testing.T, protocol uint, mode SyncMode) { defer tester.terminate() chain := testChainBase.shorten(1) - tester.newPeer("attack", protocol, chain.blocks[1:]) - if err := tester.sync("attack", big.NewInt(1000000), mode); err != errStallingPeer { + dlp := tester.newPeer("attack", protocol, chain.blocks[1:]) + fakeTD := big.NewInt(1000000) + dlp.fakeTD = fakeTD + if err := tester.sync("attack", fakeTD, mode); err != errStallingPeer { t.Fatalf("synchronisation error mismatch: have %v, want %v", err, errStallingPeer) } }