feat: add migration checker cmd #1114

Open · wants to merge 12 commits into base: develop
272 changes: 272 additions & 0 deletions cmd/migration-checker/main.go
@@ -0,0 +1,272 @@
package main

import (
"bytes"
"encoding/hex"
"flag"
"fmt"
"os"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/scroll-tech/go-ethereum/common"
"github.com/scroll-tech/go-ethereum/core/types"
"github.com/scroll-tech/go-ethereum/crypto"
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
"github.com/scroll-tech/go-ethereum/rlp"
"github.com/scroll-tech/go-ethereum/trie"
)

var accountsDone atomic.Uint64
var trieCheckers chan struct{}

type dbs struct {
	zkDb  *leveldb.Database
	mptDb *leveldb.Database
}

func main() {
	var (
		mptDbPath             = flag.String("mpt-db", "", "path to the MPT node DB")
		zkDbPath              = flag.String("zk-db", "", "path to the ZK node DB")
		mptRoot               = flag.String("mpt-root", "", "root hash of the MPT node")
		zkRoot                = flag.String("zk-root", "", "root hash of the ZK node")
		paranoid              = flag.Bool("paranoid", false, "verifies all node contents against their expected hash")
		parallelismMultiplier = flag.Int("parallelism-multiplier", 4, "multiplier for the number of parallel workers")
	)
	flag.Parse()

	zkDb, err := leveldb.New(*zkDbPath, 1024, 128, "", true)
	panicOnError(err, "", "failed to open zk db")
	mptDb, err := leveldb.New(*mptDbPath, 1024, 128, "", true)
	panicOnError(err, "", "failed to open mpt db")

	zkRootHash := common.HexToHash(*zkRoot)
	mptRootHash := common.HexToHash(*mptRoot)

	numTrieCheckers := runtime.GOMAXPROCS(0) * (*parallelismMultiplier)
	trieCheckers = make(chan struct{}, numTrieCheckers)
	for i := 0; i < numTrieCheckers; i++ {
		trieCheckers <- struct{}{}
	}

	done := make(chan struct{})
	totalCheckers := len(trieCheckers)
	go func() {
		for {
			select {
			case <-done:
				return
			case <-time.After(time.Minute):
				fmt.Println("Active checkers:", totalCheckers-len(trieCheckers))
			}
		}
	}()
	defer close(done)

	checkTrieEquality(&dbs{
		zkDb:  zkDb,
		mptDb: mptDb,
	}, zkRootHash, mptRootHash, "", checkAccountEquality, true, *paranoid)

	for i := 0; i < numTrieCheckers; i++ {
		<-trieCheckers
	}
}

func panicOnError(err error, label, msg string) {
	if err != nil {
		panic(fmt.Sprint(label, " error: ", msg, " ", err))
	}
}
Comment on lines +73 to +77
🛠️ Refactor suggestion

Prefer returning errors instead of calling panicOnError.
This helper function can abruptly terminate the program. For production tools or libraries, returning errors often leads to more flexible handling.
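
As a sketch of that shape (wrapError is a hypothetical name, not part of this PR), the helper could hand the failure back to its caller:

func wrapError(err error, label, msg string) error {
	// Same arguments as panicOnError, but the caller decides what to do.
	if err != nil {
		return fmt.Errorf("%s error: %s: %w", label, msg, err)
	}
	return nil
}

Call sites in main could then log and exit, retry, or skip the offending subtree:

if err := wrapError(err, "", "failed to open zk db"); err != nil {
	fmt.Fprintln(os.Stderr, err)
	os.Exit(1)
}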


func dup(s []byte) []byte {
	return append([]byte{}, s...)
}

func checkTrieEquality(dbs *dbs, zkRoot, mptRoot common.Hash, label string, leafChecker func(string, *dbs, []byte, []byte, bool), top, paranoid bool) {
	done := make(chan struct{})
	start := time.Now()
	if !top {
		go func() {
			for {
				select {
				case <-done:
					return
				case <-time.After(time.Minute):
					fmt.Println("Checking trie", label, "for", time.Since(start))
				}
			}
		}()
	}
	defer close(done)

	zkTrie, err := trie.NewZkTrie(zkRoot, trie.NewZktrieDatabaseFromTriedb(trie.NewDatabaseWithConfig(dbs.zkDb, &trie.Config{Preimages: true})))
	panicOnError(err, label, "failed to create zk trie")
	mptTrie, err := trie.NewSecureNoTracer(mptRoot, trie.NewDatabaseWithConfig(dbs.mptDb, &trie.Config{Preimages: true}))
	panicOnError(err, label, "failed to create mpt trie")

	mptLeafCh := loadMPT(mptTrie, top)
	zkLeafCh := loadZkTrie(zkTrie, top, paranoid)

	mptLeafMap := <-mptLeafCh
	zkLeafMap := <-zkLeafCh

	if len(mptLeafMap) != len(zkLeafMap) {
		panic(fmt.Sprintf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d", label, len(mptLeafMap), len(zkLeafMap)))
	}

	for preimageKey, zkValue := range zkLeafMap {
		if top {
			// ZkTrie pads preimages with 0s to make them 32 bytes,
			// so we might need to strip those zeroes here since we need 20-byte addresses at the top level (i.e. the state trie).
			if len(preimageKey) > 20 {
				for _, b := range []byte(preimageKey)[20:] {
					if b != 0 {
						panic(fmt.Sprintf("%s padded byte is not 0 (preimage %s)", label, hex.EncodeToString([]byte(preimageKey))))
					}
				}
				preimageKey = preimageKey[:20]
			}
		} else if len(preimageKey) != 32 {
			// storage leaves should have 32-byte keys, pad them if needed
			zeroes := make([]byte, 32)
			copy(zeroes, []byte(preimageKey))
			preimageKey = string(zeroes)
		}

		mptKey := crypto.Keccak256([]byte(preimageKey))
		mptVal, ok := mptLeafMap[string(mptKey)]
		if !ok {
			panic(fmt.Sprintf("%s key %s (preimage %s) not found in mpt", label, hex.EncodeToString(mptKey), hex.EncodeToString([]byte(preimageKey))))
		}

		leafChecker(fmt.Sprintf("%s key: %s", label, hex.EncodeToString([]byte(preimageKey))), dbs, zkValue, mptVal, paranoid)
	}
}

🛠️ Refactor suggestion

Consider returning an error for leaf mismatch.

Currently, the function panics if the trie leaf counts or data differ. Exiting abruptly may not be ideal for larger tools or services. Providing retries or partial recoveries is often more user-friendly.
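
For instance, the leaf-count check could surface an error instead of panicking; this is a sketch assuming checkTrieEquality is changed to return error, which is not its current signature:

if len(mptLeafMap) != len(zkLeafMap) {
	// Report the mismatch and let the caller decide whether to retry,
	// log, or abort, instead of tearing down the whole process.
	return fmt.Errorf("%s MPT and ZK trie leaf count mismatch: MPT: %d, ZK: %d",
		label, len(mptLeafMap), len(zkLeafMap))
}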


func checkAccountEquality(label string, dbs *dbs, zkAccountBytes, mptAccountBytes []byte, paranoid bool) {
	mptAccount := &types.StateAccount{}
	panicOnError(rlp.DecodeBytes(mptAccountBytes, mptAccount), label, "failed to decode mpt account")
	zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
	panicOnError(err, label, "failed to decode zk account")

	if mptAccount.Nonce != zkAccount.Nonce {
		panic(fmt.Sprintf("%s nonce mismatch: zk: %d, mpt: %d", label, zkAccount.Nonce, mptAccount.Nonce))
	}

	if mptAccount.Balance.Cmp(zkAccount.Balance) != 0 {
		panic(fmt.Sprintf("%s balance mismatch: zk: %s, mpt: %s", label, zkAccount.Balance.String(), mptAccount.Balance.String()))
	}

	if !bytes.Equal(mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash) {
		panic(fmt.Sprintf("%s code hash mismatch: zk: %s, mpt: %s", label, hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash)))
	}

	if (zkAccount.Root == common.Hash{}) != (mptAccount.Root == types.EmptyRootHash) {
		panic(fmt.Sprintf("%s empty account root mismatch", label))
	} else if zkAccount.Root != (common.Hash{}) {
		zkRoot := common.BytesToHash(zkAccount.Root[:])
		mptRoot := common.BytesToHash(mptAccount.Root[:])
		<-trieCheckers
		go func() {
			defer func() {
				if p := recover(); p != nil {
					fmt.Println(p)
					os.Exit(1)
				}
			}()

			checkTrieEquality(dbs, zkRoot, mptRoot, label, checkStorageEquality, false, paranoid)
			accountsDone.Add(1)
			fmt.Println("Accounts done:", accountsDone.Load())
			trieCheckers <- struct{}{}
		}()
	} else {
		accountsDone.Add(1)
		fmt.Println("Accounts done:", accountsDone.Load())
	}
}

func checkStorageEquality(label string, _ *dbs, zkStorageBytes, mptStorageBytes []byte, _ bool) {
	zkValue := common.BytesToHash(zkStorageBytes)
	_, content, _, err := rlp.Split(mptStorageBytes)
	panicOnError(err, label, "failed to decode mpt storage")
	mptValue := common.BytesToHash(content)
	if !bytes.Equal(zkValue[:], mptValue[:]) {
		panic(fmt.Sprintf("%s storage mismatch: zk: %s, mpt: %s", label, zkValue.Hex(), mptValue.Hex()))
	}
}

func loadMPT(mptTrie *trie.SecureTrie, top bool) chan map[string][]byte {
	startKey := make([]byte, 32)
	workers := 1 << 5
	if !top {
		workers = 1 << 3
	}
	step := byte(0xFF) / byte(workers)

	mptLeafMap := make(map[string][]byte, 1000)
	var mptLeafMutex sync.Mutex

	var mptWg sync.WaitGroup
	for i := 0; i < workers; i++ {
		startKey[0] = byte(i) * step
		trieIt := trie.NewIterator(mptTrie.NodeIterator(startKey))

		mptWg.Add(1)
		go func() {
			defer mptWg.Done()
			for trieIt.Next() {
				mptLeafMutex.Lock()

				if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
					mptLeafMutex.Unlock()
					break
				}

				mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)

				mptLeafMutex.Unlock()

				if top && len(mptLeafMap)%10000 == 0 {
					fmt.Println("MPT Accounts Loaded:", len(mptLeafMap))
				}
			}
		}()
	}

	respChan := make(chan map[string][]byte)
	go func() {
		mptWg.Wait()
		respChan <- mptLeafMap
	}()
	return respChan
}

🛠️ Refactor suggestion

Prevent partial load after duplicate key detection.

Currently, when one goroutine encounters a duplicate key and breaks, the others continue loading. This can cause partial or inconsistent data. Consider coordinating with a shared flag or context cancellation so all goroutines stop together.
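
One possible coordination, sketched with a shared atomic.Bool (the stop variable is illustrative, not in the PR): every worker checks the flag before each step and sets it when it detects a duplicate, so all iterators wind down together:

var stop atomic.Bool // shared by all loadMPT workers

go func() {
	defer mptWg.Done()
	for trieIt.Next() {
		if stop.Load() {
			return // another worker already found a duplicate
		}
		mptLeafMutex.Lock()
		if _, ok := mptLeafMap[string(trieIt.Key)]; ok {
			mptLeafMutex.Unlock()
			stop.Store(true) // signal the remaining workers to stop too
			return
		}
		mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)
		mptLeafMutex.Unlock()
	}
}()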


func loadZkTrie(zkTrie *trie.ZkTrie, top, paranoid bool) chan map[string][]byte {
	zkLeafMap := make(map[string][]byte, 1000)
	var zkLeafMutex sync.Mutex
	zkDone := make(chan map[string][]byte)
	go func() {
		zkTrie.CountLeaves(func(key, value []byte) {
			preimageKey := zkTrie.GetKey(key)
			if len(preimageKey) == 0 {
				panic(fmt.Sprintf("preimage not found in zk trie %s", hex.EncodeToString(key)))
			}

			zkLeafMutex.Lock()

			zkLeafMap[string(dup(preimageKey))] = value

			zkLeafMutex.Unlock()

			if top && len(zkLeafMap)%10000 == 0 {
				fmt.Println("ZK Accounts Loaded:", len(zkLeafMap))
			}
		}, top, paranoid)
		zkDone <- zkLeafMap
	}()
	return zkDone
}
10 changes: 10 additions & 0 deletions trie/secure_trie.go
@@ -65,6 +65,16 @@ func NewSecure(root common.Hash, db *Database) (*SecureTrie, error) {
	return &SecureTrie{trie: *trie, preimages: db.preimages}, nil
}

// NewSecureNoTracer returns a SecureTrie with its node tracer disabled,
// for read-only use where trie mutations do not need to be tracked.
func NewSecureNoTracer(root common.Hash, db *Database) (*SecureTrie, error) {
	t, err := NewSecure(root, db)
	if err != nil {
		return nil, err
	}

	t.trie.tracer = nil
	return t, nil
}

// Get returns the value for key stored in the trie.
// The value bytes must not be modified by the caller.
func (t *SecureTrie) Get(key []byte) []byte {
24 changes: 24 additions & 0 deletions trie/tracer.go
@@ -60,13 +60,21 @@ func newTracer() *tracer {
// blob internally. Don't change the value outside of function since
// it's not deep-copied.
func (t *tracer) onRead(path []byte, val []byte) {
	if t == nil {
		return
	}

	t.accessList[string(path)] = val
}

// onInsert tracks the newly inserted trie node. If it's already
// in the deletion set (resurrected node), then just wipe it from
// the deletion set as it's "untouched".
func (t *tracer) onInsert(path []byte) {
	if t == nil {
		return
	}

	if _, present := t.deletes[string(path)]; present {
		delete(t.deletes, string(path))
		return
@@ -78,6 +86,10 @@ func (t *tracer) onInsert(path []byte) {
// in the addition set, then just wipe it from the addition set
// as it's untouched.
func (t *tracer) onDelete(path []byte) {
	if t == nil {
		return
	}

	if _, present := t.inserts[string(path)]; present {
		delete(t.inserts, string(path))
		return
@@ -87,13 +99,21 @@ func (t *tracer) onDelete(path []byte) {

// reset clears the content tracked by tracer.
func (t *tracer) reset() {
	if t == nil {
		return
	}

	t.inserts = make(map[string]struct{})
	t.deletes = make(map[string]struct{})
	t.accessList = make(map[string][]byte)
}

// copy returns a deep copied tracer instance.
func (t *tracer) copy() *tracer {
	if t == nil {
		return nil
	}

	accessList := make(map[string][]byte, len(t.accessList))
	for path, blob := range t.accessList {
		accessList[path] = common.CopyBytes(blob)
@@ -107,6 +127,10 @@ func (t *tracer) copy() *tracer {

// deletedNodes returns a list of node paths which are deleted from the trie.
func (t *tracer) deletedNodes() []string {
	if t == nil {
		return nil
	}

	var paths []string
	for path := range t.deletes {
		// It's possible a few deleted nodes were embedded