Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core/state: semantic journalling (part 1) #28880

Merged
merged 9 commits into from
Aug 28, 2024
164 changes: 129 additions & 35 deletions core/state/journal.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@
package state

import (
"fmt"
"maps"
"slices"
"sort"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/holiman/uint256"
)

type revision struct {
id int
journalIndex int
}

// journalEntry is a modification entry in the state change journal that can be
// reverted on demand.
type journalEntry interface {
Expand All @@ -42,6 +51,9 @@ type journalEntry interface {
type journal struct {
entries []journalEntry // Current changes tracked by the journal
dirties map[common.Address]int // Dirty accounts and the number of changes

validRevisions []revision
nextRevisionId int
holiman marked this conversation as resolved.
Show resolved Hide resolved
}

// newJournal creates a new initialized journal.
Expand All @@ -51,6 +63,40 @@ func newJournal() *journal {
}
}

// Reset clears the journal, after this operation the journal can be used
// anew. It is semantically similar to calling 'newJournal', but the underlying
// slices can be reused
func (j *journal) Reset() {
j.entries = j.entries[:0]
j.validRevisions = j.validRevisions[:0]
clear(j.dirties)
j.nextRevisionId = 0
}

// Snapshot returns an identifier for the current revision of the state.
func (j *journal) Snapshot() int {
id := j.nextRevisionId
j.nextRevisionId++
j.validRevisions = append(j.validRevisions, revision{id, j.length()})
return id
}

// RevertToSnapshot reverts all state changes made since the given revision.
func (j *journal) RevertToSnapshot(revid int, s *StateDB) {
// Find the snapshot in the stack of valid snapshots.
idx := sort.Search(len(j.validRevisions), func(i int) bool {
return j.validRevisions[i].id >= revid
})
if idx == len(j.validRevisions) || j.validRevisions[idx].id != revid {
panic(fmt.Errorf("revision id %v cannot be reverted", revid))
}
snapshot := j.validRevisions[idx].journalIndex

// Replay the journal to undo changes and remove invalidated snapshots
j.revert(s, snapshot)
j.validRevisions = j.validRevisions[:idx]
}

// append inserts a new modification entry to the end of the change journal.
func (j *journal) append(entry journalEntry) {
j.entries = append(j.entries, entry)
Expand Down Expand Up @@ -95,8 +141,83 @@ func (j *journal) copy() *journal {
entries = append(entries, j.entries[i].copy())
}
return &journal{
entries: entries,
dirties: maps.Clone(j.dirties),
entries: entries,
dirties: maps.Clone(j.dirties),
validRevisions: slices.Clone(j.validRevisions),
nextRevisionId: j.nextRevisionId,
}
}

func (j *journal) AccessListAddAccount(addr common.Address) {
j.append(accessListAddAccountChange{&addr})
}

func (j *journal) AccessListAddSlot(addr common.Address, slot common.Hash) {
j.append(accessListAddSlotChange{
address: &addr,
slot: &slot,
})
}

func (j *journal) Log(txHash common.Hash) {
j.append(addLogChange{txhash: txHash})
}

func (j *journal) Create(addr common.Address) {
j.append(createObjectChange{account: &addr})
}

func (j *journal) Destruct(addr common.Address) {
j.append(selfDestructChange{account: &addr})
}

func (j *journal) SetStorage(addr common.Address, key, prev, origin common.Hash) {
j.append(storageChange{
account: &addr,
key: key,
prevvalue: prev,
origvalue: origin,
})
}

func (j *journal) SetTransientState(addr common.Address, key, prev common.Hash) {
j.append(transientStorageChange{
account: &addr,
key: key,
prevalue: prev,
})
}

func (j *journal) RefundChange(previous uint64) {
j.append(refundChange{prev: previous})
}

func (j *journal) BalanceChange(addr common.Address, previous *uint256.Int) {
j.append(balanceChange{
account: &addr,
prev: previous.Clone(),
})
}

func (j *journal) SetCode(address common.Address) {
j.append(codeChange{account: &address})
}

func (j *journal) NonceChange(address common.Address, prev uint64) {
j.append(nonceChange{
account: &address,
prev: prev,
})
}

func (j *journal) Touch(address common.Address) {
j.append(touchChange{
account: &address,
})
if address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
j.dirty(address)
}
}

Expand All @@ -114,9 +235,7 @@ type (
}

selfDestructChange struct {
account *common.Address
prev bool // whether account had already self-destructed
prevbalance *uint256.Int
account *common.Address
}

// Changes to individual accounts.
Expand All @@ -135,8 +254,7 @@ type (
origvalue common.Hash
}
codeChange struct {
account *common.Address
prevcode, prevhash []byte
account *common.Address
}

// Changes to other state values.
Expand All @@ -146,9 +264,6 @@ type (
addLogChange struct {
txhash common.Hash
}
addPreimageChange struct {
hash common.Hash
}
touchChange struct {
account *common.Address
}
Expand Down Expand Up @@ -200,8 +315,7 @@ func (ch createContractChange) copy() journalEntry {
func (ch selfDestructChange) revert(s *StateDB) {
obj := s.getStateObject(*ch.account)
if obj != nil {
obj.selfDestructed = ch.prev
obj.setBalance(ch.prevbalance)
obj.selfDestructed = false
}
}

Expand All @@ -211,9 +325,7 @@ func (ch selfDestructChange) dirtied() *common.Address {

func (ch selfDestructChange) copy() journalEntry {
return selfDestructChange{
account: ch.account,
prev: ch.prev,
prevbalance: new(uint256.Int).Set(ch.prevbalance),
account: ch.account,
}
}

Expand Down Expand Up @@ -263,19 +375,15 @@ func (ch nonceChange) copy() journalEntry {
}

func (ch codeChange) revert(s *StateDB) {
s.getStateObject(*ch.account).setCode(common.BytesToHash(ch.prevhash), ch.prevcode)
s.getStateObject(*ch.account).setCode(types.EmptyCodeHash, nil)
}

func (ch codeChange) dirtied() *common.Address {
return ch.account
}

func (ch codeChange) copy() journalEntry {
return codeChange{
account: ch.account,
prevhash: common.CopyBytes(ch.prevhash),
prevcode: common.CopyBytes(ch.prevcode),
}
return codeChange{account: ch.account}
}

func (ch storageChange) revert(s *StateDB) {
Expand Down Expand Up @@ -344,20 +452,6 @@ func (ch addLogChange) copy() journalEntry {
}
}

func (ch addPreimageChange) revert(s *StateDB) {
delete(s.preimages, ch.hash)
}

func (ch addPreimageChange) dirtied() *common.Address {
return nil
}

func (ch addPreimageChange) copy() journalEntry {
return addPreimageChange{
hash: ch.hash,
}
}

func (ch accessListAddAccountChange) revert(s *StateDB) {
/*
One important invariant here, is that whenever a (addr, slot) is added, if the
Expand Down
38 changes: 8 additions & 30 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,7 @@ func (s *stateObject) markSelfdestructed() {
}

func (s *stateObject) touch() {
s.db.journal.append(touchChange{
account: &s.address,
})
if s.address == ripemd {
// Explicitly put it in the dirty-cache, which is otherwise generated from
// flattened journals.
s.db.journal.dirty(s.address)
}
s.db.journal.Touch(s.address)
}

// getTrie returns the associated storage trie. The trie will be opened if it's
Expand Down Expand Up @@ -251,16 +244,11 @@ func (s *stateObject) SetState(key, value common.Hash) {
return
}
// New value is different, update and journal the change
s.db.journal.append(storageChange{
account: &s.address,
key: key,
prevvalue: prev,
origvalue: origin,
})
s.db.journal.SetStorage(s.address, key, prev, origin)
s.setState(key, value, origin)
if s.db.logger != nil && s.db.logger.OnStorageChange != nil {
s.db.logger.OnStorageChange(s.address, key, prev, value)
}
s.setState(key, value, origin)
}

// setState updates a value in account dirty storage. The dirtiness will be
Expand Down Expand Up @@ -510,10 +498,7 @@ func (s *stateObject) SubBalance(amount *uint256.Int, reason tracing.BalanceChan
}

func (s *stateObject) SetBalance(amount *uint256.Int, reason tracing.BalanceChangeReason) {
s.db.journal.append(balanceChange{
account: &s.address,
prev: new(uint256.Int).Set(s.data.Balance),
})
s.db.journal.BalanceChange(s.address, s.data.Balance)
if s.db.logger != nil && s.db.logger.OnBalanceChange != nil {
s.db.logger.OnBalanceChange(s.address, s.Balance().ToBig(), amount.ToBig(), reason)
}
Expand Down Expand Up @@ -589,14 +574,10 @@ func (s *stateObject) CodeSize() int {
}

func (s *stateObject) SetCode(codeHash common.Hash, code []byte) {
prevcode := s.Code()
s.db.journal.append(codeChange{
account: &s.address,
prevhash: s.CodeHash(),
prevcode: prevcode,
})
s.db.journal.SetCode(s.address)
rjl493456442 marked this conversation as resolved.
Show resolved Hide resolved
if s.db.logger != nil && s.db.logger.OnCodeChange != nil {
s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), prevcode, codeHash, code)
// TODO remove prevcode from this callback
s.db.logger.OnCodeChange(s.address, common.BytesToHash(s.CodeHash()), nil, codeHash, code)
}
s.setCode(codeHash, code)
}
Expand All @@ -608,10 +589,7 @@ func (s *stateObject) setCode(codeHash common.Hash, code []byte) {
}

func (s *stateObject) SetNonce(nonce uint64) {
s.db.journal.append(nonceChange{
account: &s.address,
prev: s.data.Nonce,
})
s.db.journal.NonceChange(s.address, s.data.Nonce)
if s.db.logger != nil && s.db.logger.OnNonceChange != nil {
s.db.logger.OnNonceChange(s.address, s.data.Nonce, nonce)
}
Expand Down
Loading