Skip to content

Commit

Permalink
Added Optimizations to Beacon state transition. (erigontech#6792)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giulio2002 authored Feb 7, 2023
1 parent 683f022 commit a585ae4
Show file tree
Hide file tree
Showing 21 changed files with 341 additions and 236 deletions.
4 changes: 4 additions & 0 deletions cl/cltypes/eth1_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ type Eth1Data struct {
DepositCount uint64
}

func (e *Eth1Data) Equal(b *Eth1Data) bool {
return e.BlockHash == b.BlockHash && e.Root == b.Root && b.DepositCount == e.DepositCount
}

// MarshalSSZTo ssz marshals the Eth1Data object to a target array
func (e *Eth1Data) EncodeSSZ(buf []byte) (dst []byte, err error) {
dst = buf
Expand Down
10 changes: 10 additions & 0 deletions cl/utils/bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,13 @@ func DecompressZstd(b []byte) ([]byte, error) {
}
return r.DecodeAll(b, nil)
}

// Check if it is sorted and check if there are duplicates. O(N) complexity.
func IsSliceSortedSet(vals []uint64) bool {
for i := 0; i < len(vals)-1; i++ {
if vals[i] >= vals[i+1] {
return false
}
}
return true
}
19 changes: 18 additions & 1 deletion cl/utils/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ import (
"sync"
)

type HashFunc func(data []byte, extras ...[]byte) [32]byte

var hasherPool = sync.Pool{
New: func() interface{} {
return sha256.New()
},
}

// General purpose Keccak256
func Keccak256(data []byte, extras ...[]byte) [32]byte {
h, ok := hasherPool.Get().(hash.Hash)
if !ok {
Expand All @@ -40,6 +43,20 @@ func Keccak256(data []byte, extras ...[]byte) [32]byte {
h.Write(extra)
}
h.Sum(b[:0])

return b
}

// Optimized Keccak256, avoid pool.put/pool.get, meant for intensive operations.
func OptimizedKeccak256() HashFunc {
h := sha256.New()
return func(data []byte, extras ...[]byte) [32]byte {
h.Reset()
var b [32]byte
h.Write(data)
for _, extra := range extras {
h.Write(extra)
}
h.Sum(b[:0])
return b
}
}
84 changes: 58 additions & 26 deletions cmd/erigon-cl/core/state/accessors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,21 @@ import (
"github.com/ledgerwatch/erigon/cl/cltypes"
"github.com/ledgerwatch/erigon/cl/fork"
"github.com/ledgerwatch/erigon/cl/utils"
eth2_shuffle "github.com/protolambda/eth2-shuffle"
)

// GetActiveValidatorsIndices returns the list of validator indices active for the given epoch.
func (b *BeaconState) GetActiveValidatorsIndices(epoch uint64) (indicies []uint64) {
if cachedIndicies, ok := b.activeValidatorsCache.Get(epoch); ok {
return cachedIndicies.([]uint64)
}
for i, validator := range b.validators {
if !validator.Active(epoch) {
continue
}
indicies = append(indicies, uint64(i))
}
b.activeValidatorsCache.Add(epoch, indicies)
return
}

Expand Down Expand Up @@ -89,7 +94,10 @@ func (b *BeaconState) GetTotalBalance(validatorSet []uint64) (uint64, error) {

// GetTotalActiveBalance return the sum of all balances within active validators.
func (b *BeaconState) GetTotalActiveBalance() (uint64, error) {
return b.GetTotalBalance(b.GetActiveValidatorsIndices(b.Epoch()))
if b.totalActiveBalanceCache < b.beaconConfig.EffectiveBalanceIncrement {
return b.beaconConfig.EffectiveBalanceIncrement, nil
}
return b.totalActiveBalanceCache, nil
}

// GetTotalSlashingAmount return the sum of all slashings.
Expand Down Expand Up @@ -129,18 +137,24 @@ func (b *BeaconState) GetDomain(domainType [4]byte, epoch uint64) ([]byte, error
return fork.ComputeDomain(domainType[:], forkVersion, b.genesisValidatorsRoot)
}

func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) (uint64, error) {
func (b *BeaconState) ComputeShuffledIndexPreInputs(seed [32]byte) [][32]byte {
ret := make([][32]byte, b.beaconConfig.ShuffleRoundCount)
for i := range ret {
ret[i] = utils.Keccak256(append(seed[:], byte(i)))
}
return ret
}

func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte, preInputs [][32]byte, hashFunc utils.HashFunc) (uint64, error) {
if ind >= ind_count {
return 0, fmt.Errorf("index=%d must be less than the index count=%d", ind, ind_count)
}

if len(preInputs) == 0 {
preInputs = b.ComputeShuffledIndexPreInputs(seed)
}
for i := uint64(0); i < b.beaconConfig.ShuffleRoundCount; i++ {
// Construct first hash input.
input := append(seed[:], byte(i))
hashedInput := utils.Keccak256(input)

// Read hash value.
hashValue := binary.LittleEndian.Uint64(hashedInput[:8])
hashValue := binary.LittleEndian.Uint64(preInputs[i][:8])

// Caclulate pivot and flip.
pivot := hashValue % ind_count
Expand All @@ -157,8 +171,7 @@ func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte)
binary.LittleEndian.PutUint32(positionByteArray, uint32(position>>8))
input2 := append(seed[:], byte(i))
input2 = append(input2, positionByteArray...)

hashedInput2 := utils.Keccak256(input2)
hashedInput2 := hashFunc(input2)
// Read hash value.
byteVal := hashedInput2[(position%256)/8]
bitVal := (byteVal >> (position % 8)) % 2
Expand All @@ -169,18 +182,24 @@ func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte)
return ind, nil
}

func (b *BeaconState) ComputeCommittee(indicies []uint64, seed libcommon.Hash, index, count uint64) ([]uint64, error) {
ret := []uint64{}
func (b *BeaconState) ComputeCommittee(indicies []uint64, seed libcommon.Hash, index, count uint64, preInputs [][32]byte, hashFunc utils.HashFunc) ([]uint64, error) {
lenIndicies := uint64(len(indicies))
for i := (lenIndicies * index) / count; i < (lenIndicies*(index+1))/count; i++ {
index, err := b.ComputeShuffledIndex(i, lenIndicies, seed)
if err != nil {
return nil, err
start := (lenIndicies * index) / count
end := (lenIndicies * (index + 1)) / count
var shuffledIndicies []uint64
if shuffledIndicesInterface, ok := b.shuffledSetsCache.Get(seed); ok {
shuffledIndicies = shuffledIndicesInterface.([]uint64)
} else {
shuffledIndicies = make([]uint64, lenIndicies)
copy(shuffledIndicies, indicies)
eth2ShuffleHashFunc := func(data []byte) []byte {
hashed := hashFunc(data)
return hashed[:]
}
ret = append(ret, indicies[index])
eth2_shuffle.UnshuffleList(eth2ShuffleHashFunc, shuffledIndicies, uint8(b.beaconConfig.ShuffleRoundCount), seed)
b.shuffledSetsCache.Add(seed, shuffledIndicies)
}
return ret, nil
//return [indices[compute_shuffled_index(uint64(i), uint64(len(indices)), seed)] for i in range(start, end)]
return shuffledIndicies[start:end], nil
}

func (b *BeaconState) ComputeProposerIndex(indices []uint64, seed [32]byte) (uint64, error) {
Expand All @@ -191,8 +210,9 @@ func (b *BeaconState) ComputeProposerIndex(indices []uint64, seed [32]byte) (uin
i := uint64(0)
total := uint64(len(indices))
buf := make([]byte, 8)
preInputs := b.ComputeShuffledIndexPreInputs(seed)
for {
shuffled, err := b.ComputeShuffledIndex(i%total, total, seed)
shuffled, err := b.ComputeShuffledIndex(i%total, total, seed, preInputs, utils.Keccak256)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -348,21 +368,33 @@ func (b *BeaconState) GetAttestationParticipationFlagIndicies(data *cltypes.Atte
}

func (b *BeaconState) GetBeaconCommitee(slot, committeeIndex uint64) ([]uint64, error) {
var cacheKey [16]byte
binary.BigEndian.PutUint64(cacheKey[:], slot)
binary.BigEndian.PutUint64(cacheKey[8:], committeeIndex)
if cachedCommittee, ok := b.committeeCache.Get(cacheKey); ok {
return cachedCommittee.([]uint64), nil
}
epoch := b.GetEpochAtSlot(slot)
committeesPerSlot := b.CommitteeCount(epoch)
return b.ComputeCommittee(
seed := b.GetSeed(epoch, b.beaconConfig.DomainBeaconAttester)
preInputs := b.ComputeShuffledIndexPreInputs(seed)
hashFunc := utils.OptimizedKeccak256()
committee, err := b.ComputeCommittee(
b.GetActiveValidatorsIndices(epoch),
b.GetSeed(epoch, b.beaconConfig.DomainBeaconAttester),
seed,
(slot%b.beaconConfig.SlotsPerEpoch)*committeesPerSlot+committeeIndex,
committeesPerSlot*b.beaconConfig.SlotsPerEpoch,
preInputs,
hashFunc,
)
}

func (b *BeaconState) GetIndexedAttestation(attestation *cltypes.Attestation) (*cltypes.IndexedAttestation, error) {
attestingIndicies, err := b.GetAttestingIndicies(attestation.Data, attestation.AggregationBits)
if err != nil {
return nil, err
}
b.committeeCache.Add(cacheKey, committee)
return committee, nil
}

func (b *BeaconState) GetIndexedAttestation(attestation *cltypes.Attestation, attestingIndicies []uint64) (*cltypes.IndexedAttestation, error) {
// Sort the the attestation indicies.
sort.Slice(attestingIndicies, func(i, j int) bool {
return attestingIndicies[i] < attestingIndicies[j]
Expand Down
16 changes: 9 additions & 7 deletions cmd/erigon-cl/core/state/accessors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ func TestActiveValidatorIndices(t *testing.T) {
ActivationEpoch: 3,
ExitEpoch: 9,
EffectiveBalance: 2e9,
})
}, 2e9)
// Active Validator
testState.AddValidator(&cltypes.Validator{
ActivationEpoch: 1,
ExitEpoch: 9,
EffectiveBalance: 2e9,
})
}, 2e9)
testState.SetSlot(epoch * 32) // Epoch
testFlags := cltypes.ParticipationFlagsListFromBytes([]byte{1, 1})
testState.SetCurrentEpochParticipation(testFlags)
Expand Down Expand Up @@ -148,7 +148,8 @@ func TestComputeShuffledIndex(t *testing.T) {
t.Run(tc.description, func(t *testing.T) {
for i, val := range tc.startInds {
state := state.New(&clparams.MainnetBeaconConfig)
got, err := state.ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed)
preInputs := state.ComputeShuffledIndexPreInputs(tc.seed)
got, err := state.ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed, preInputs, utils.Keccak256)
// Non-failure case.
if err != nil {
t.Errorf("unexpected error: %v", err)
Expand All @@ -164,7 +165,7 @@ func TestComputeShuffledIndex(t *testing.T) {
func generateBeaconStateWithValidators(n int) *state.BeaconState {
b := state.GetEmptyBeaconState()
for i := 0; i < n; i++ {
b.AddValidator(&cltypes.Validator{EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance})
b.AddValidator(&cltypes.Validator{EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}, clparams.MainnetBeaconConfig.MaxEffectiveBalance)
}
return b
}
Expand Down Expand Up @@ -237,7 +238,7 @@ func TestComputeProposerIndex(t *testing.T) {

func TestSyncReward(t *testing.T) {
s := state.GetEmptyBeaconState()
s.AddValidator(&cltypes.Validator{EffectiveBalance: 3099999999909, ExitEpoch: 2})
s.AddValidator(&cltypes.Validator{EffectiveBalance: 3099999999909, ExitEpoch: 2}, 3099999999909)
propReward, partRew, err := s.SyncRewards()
require.NoError(t, err)
require.Equal(t, propReward, uint64(30))
Expand Down Expand Up @@ -265,12 +266,13 @@ func TestComputeCommittee(t *testing.T) {
epoch := state.Epoch()
indices := state.GetActiveValidatorsIndices(epoch)
seed := state.GetSeed(epoch, clparams.MainnetBeaconConfig.DomainBeaconAttester)
committees, err := state.ComputeCommittee(indices, seed, 0, 1)
preInputs := state.ComputeShuffledIndexPreInputs(seed)
committees, err := state.ComputeCommittee(indices, seed, 0, 1, preInputs, utils.Keccak256)
require.NoError(t, err, "Could not compute committee")

// Test shuffled indices are correct for index 5 committee
index := uint64(5)
committee5, err := state.ComputeCommittee(indices, seed, index, committeeCount)
committee5, err := state.ComputeCommittee(indices, seed, index, committeeCount, preInputs, utils.Keccak256)
require.NoError(t, err, "Could not compute committee")
start := (validatorCount * index) / committeeCount
end := (validatorCount * (index + 1)) / committeeCount
Expand Down
14 changes: 4 additions & 10 deletions cmd/erigon-cl/core/state/mutators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ const (

func getTestStateBalances(t *testing.T) *state.BeaconState {
numVals := uint64(2048)
balances := make([]uint64, numVals)
b := state.GetEmptyBeaconState()
for i := uint64(0); i < numVals; i++ {
balances[i] = i
b.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch}, i)
}
b := state.GetEmptyBeaconState()
b.SetBalances(balances)
return b
}

Expand All @@ -45,9 +43,7 @@ func TestIncreaseBalance(t *testing.T) {
beforeBalance := state.Balances()[testInd]
state.IncreaseBalance(int(testInd), amount)
afterBalance := state.Balances()[testInd]
if afterBalance != beforeBalance+amount {
t.Errorf("unepected after balance: %d, before balance: %d, increase: %d", afterBalance, beforeBalance, amount)
}
require.Equal(t, afterBalance, beforeBalance+amount)
}

func TestDecreaseBalance(t *testing.T) {
Expand Down Expand Up @@ -82,9 +78,7 @@ func TestDecreaseBalance(t *testing.T) {
state := getTestStateBalances(t)
require.NoError(t, state.DecreaseBalance(testInd, tc.delta))
afterBalance := state.Balances()[testInd]
if afterBalance != tc.expectedBalance {
t.Errorf("unexpected resulting balance: got %d, want %d", afterBalance, tc.expectedBalance)
}
require.Equal(t, afterBalance, tc.expectedBalance)
})
}
}
Expand Down
21 changes: 15 additions & 6 deletions cmd/erigon-cl/core/state/setters.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func (b *BeaconState) SetGenesisValidatorsRoot(genesisValidatorRoot libcommon.Ha
func (b *BeaconState) SetSlot(slot uint64) {
b.touchedLeaves[SlotLeafIndex] = true
b.slot = slot
// If there is a new slot update the active balance cache.
b._refreshActiveBalances()
}

func (b *BeaconState) SetFork(fork *cltypes.Fork) {
Expand Down Expand Up @@ -61,6 +63,9 @@ func (b *BeaconState) SetValidatorAt(index int, validator *cltypes.Validator) er
return InvalidValidatorIndex
}
b.validators[index] = validator
// change in validator set means cache purging
b.activeValidatorsCache.Purge()
b._refreshActiveBalances()
return nil
}

Expand Down Expand Up @@ -91,26 +96,30 @@ func (b *BeaconState) SetValidators(validators []*cltypes.Validator) {
b.initBeaconState()
}

func (b *BeaconState) AddValidator(validator *cltypes.Validator) {
func (b *BeaconState) AddValidator(validator *cltypes.Validator, balance uint64) {
b.touchedLeaves[ValidatorsLeafIndex] = true
b.validators = append(b.validators, validator)
b.balances = append(b.balances, balance)
if validator.Active(b.Epoch()) {
b.totalActiveBalanceCache += validator.EffectiveBalance
}
b.publicKeyIndicies[validator.PublicKey] = uint64(len(b.validators)) - 1
// change in validator set means cache purging
b.activeValidatorsCache.Purge()

}

func (b *BeaconState) SetBalances(balances []uint64) {
b.touchedLeaves[BalancesLeafIndex] = true
b.balances = balances
}

func (b *BeaconState) AddBalance(balance uint64) {
b.touchedLeaves[BalancesLeafIndex] = true
b.balances = append(b.balances, balance)
b._refreshActiveBalances()
}

func (b *BeaconState) SetValidatorBalance(index int, balance uint64) error {
if index >= len(b.balances) {
return InvalidValidatorIndex
}

b.touchedLeaves[BalancesLeafIndex] = true
b.balances[index] = balance
return nil
Expand Down
Loading

0 comments on commit a585ae4

Please sign in to comment.