diff --git a/cl/cltypes/eth1_data.go b/cl/cltypes/eth1_data.go index 3f05be9a369..b60a1352fcc 100644 --- a/cl/cltypes/eth1_data.go +++ b/cl/cltypes/eth1_data.go @@ -14,6 +14,10 @@ type Eth1Data struct { DepositCount uint64 } +func (e *Eth1Data) Equal(b *Eth1Data) bool { + return e.BlockHash == b.BlockHash && e.Root == b.Root && b.DepositCount == e.DepositCount +} + // MarshalSSZTo ssz marshals the Eth1Data object to a target array func (e *Eth1Data) EncodeSSZ(buf []byte) (dst []byte, err error) { dst = buf diff --git a/cl/utils/bytes.go b/cl/utils/bytes.go index 71e2037485b..df9989260c3 100644 --- a/cl/utils/bytes.go +++ b/cl/utils/bytes.go @@ -98,3 +98,13 @@ func DecompressZstd(b []byte) ([]byte, error) { } return r.DecodeAll(b, nil) } + +// Check if it is sorted and check if there are duplicates. O(N) complexity. +func IsSliceSortedSet(vals []uint64) bool { + for i := 0; i < len(vals)-1; i++ { + if vals[i] >= vals[i+1] { + return false + } + } + return true +} diff --git a/cl/utils/crypto.go b/cl/utils/crypto.go index a103dba36ed..23104caf2fb 100644 --- a/cl/utils/crypto.go +++ b/cl/utils/crypto.go @@ -19,12 +19,15 @@ import ( "sync" ) +type HashFunc func(data []byte, extras ...[]byte) [32]byte + var hasherPool = sync.Pool{ New: func() interface{} { return sha256.New() }, } +// General purpose Keccak256 func Keccak256(data []byte, extras ...[]byte) [32]byte { h, ok := hasherPool.Get().(hash.Hash) if !ok { @@ -40,6 +43,20 @@ func Keccak256(data []byte, extras ...[]byte) [32]byte { h.Write(extra) } h.Sum(b[:0]) - return b } + +// Optimized Keccak256, avoid pool.put/pool.get, meant for intensive operations. +func OptimizedKeccak256() HashFunc { + h := sha256.New() + return func(data []byte, extras ...[]byte) [32]byte { + h.Reset() + var b [32]byte + h.Write(data) + for _, extra := range extras { + h.Write(extra) + } + h.Sum(b[:0]) + return b + } +} diff --git a/cmd/erigon-cl/core/state/accessors.go b/cmd/erigon-cl/core/state/accessors.go index bb419a0bd33..bee33a0a9aa 100644 --- a/cmd/erigon-cl/core/state/accessors.go +++ b/cmd/erigon-cl/core/state/accessors.go @@ -11,16 +11,21 @@ import ( "github.com/ledgerwatch/erigon/cl/cltypes" "github.com/ledgerwatch/erigon/cl/fork" "github.com/ledgerwatch/erigon/cl/utils" + eth2_shuffle "github.com/protolambda/eth2-shuffle" ) // GetActiveValidatorsIndices returns the list of validator indices active for the given epoch. func (b *BeaconState) GetActiveValidatorsIndices(epoch uint64) (indicies []uint64) { + if cachedIndicies, ok := b.activeValidatorsCache.Get(epoch); ok { + return cachedIndicies.([]uint64) + } for i, validator := range b.validators { if !validator.Active(epoch) { continue } indicies = append(indicies, uint64(i)) } + b.activeValidatorsCache.Add(epoch, indicies) return } @@ -89,7 +94,10 @@ func (b *BeaconState) GetTotalBalance(validatorSet []uint64) (uint64, error) { // GetTotalActiveBalance return the sum of all balances within active validators. func (b *BeaconState) GetTotalActiveBalance() (uint64, error) { - return b.GetTotalBalance(b.GetActiveValidatorsIndices(b.Epoch())) + if b.totalActiveBalanceCache < b.beaconConfig.EffectiveBalanceIncrement { + return b.beaconConfig.EffectiveBalanceIncrement, nil + } + return b.totalActiveBalanceCache, nil } // GetTotalSlashingAmount return the sum of all slashings. @@ -129,18 +137,24 @@ func (b *BeaconState) GetDomain(domainType [4]byte, epoch uint64) ([]byte, error return fork.ComputeDomain(domainType[:], forkVersion, b.genesisValidatorsRoot) } -func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) (uint64, error) { +func (b *BeaconState) ComputeShuffledIndexPreInputs(seed [32]byte) [][32]byte { + ret := make([][32]byte, b.beaconConfig.ShuffleRoundCount) + for i := range ret { + ret[i] = utils.Keccak256(append(seed[:], byte(i))) + } + return ret +} + +func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte, preInputs [][32]byte, hashFunc utils.HashFunc) (uint64, error) { if ind >= ind_count { return 0, fmt.Errorf("index=%d must be less than the index count=%d", ind, ind_count) } - + if len(preInputs) == 0 { + preInputs = b.ComputeShuffledIndexPreInputs(seed) + } for i := uint64(0); i < b.beaconConfig.ShuffleRoundCount; i++ { - // Construct first hash input. - input := append(seed[:], byte(i)) - hashedInput := utils.Keccak256(input) - // Read hash value. - hashValue := binary.LittleEndian.Uint64(hashedInput[:8]) + hashValue := binary.LittleEndian.Uint64(preInputs[i][:8]) // Caclulate pivot and flip. pivot := hashValue % ind_count @@ -157,8 +171,7 @@ func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) binary.LittleEndian.PutUint32(positionByteArray, uint32(position>>8)) input2 := append(seed[:], byte(i)) input2 = append(input2, positionByteArray...) - - hashedInput2 := utils.Keccak256(input2) + hashedInput2 := hashFunc(input2) // Read hash value. byteVal := hashedInput2[(position%256)/8] bitVal := (byteVal >> (position % 8)) % 2 @@ -169,18 +182,24 @@ func (b *BeaconState) ComputeShuffledIndex(ind, ind_count uint64, seed [32]byte) return ind, nil } -func (b *BeaconState) ComputeCommittee(indicies []uint64, seed libcommon.Hash, index, count uint64) ([]uint64, error) { - ret := []uint64{} +func (b *BeaconState) ComputeCommittee(indicies []uint64, seed libcommon.Hash, index, count uint64, preInputs [][32]byte, hashFunc utils.HashFunc) ([]uint64, error) { lenIndicies := uint64(len(indicies)) - for i := (lenIndicies * index) / count; i < (lenIndicies*(index+1))/count; i++ { - index, err := b.ComputeShuffledIndex(i, lenIndicies, seed) - if err != nil { - return nil, err + start := (lenIndicies * index) / count + end := (lenIndicies * (index + 1)) / count + var shuffledIndicies []uint64 + if shuffledIndicesInterface, ok := b.shuffledSetsCache.Get(seed); ok { + shuffledIndicies = shuffledIndicesInterface.([]uint64) + } else { + shuffledIndicies = make([]uint64, lenIndicies) + copy(shuffledIndicies, indicies) + eth2ShuffleHashFunc := func(data []byte) []byte { + hashed := hashFunc(data) + return hashed[:] } - ret = append(ret, indicies[index]) + eth2_shuffle.UnshuffleList(eth2ShuffleHashFunc, shuffledIndicies, uint8(b.beaconConfig.ShuffleRoundCount), seed) + b.shuffledSetsCache.Add(seed, shuffledIndicies) } - return ret, nil - //return [indices[compute_shuffled_index(uint64(i), uint64(len(indices)), seed)] for i in range(start, end)] + return shuffledIndicies[start:end], nil } func (b *BeaconState) ComputeProposerIndex(indices []uint64, seed [32]byte) (uint64, error) { @@ -191,8 +210,9 @@ func (b *BeaconState) ComputeProposerIndex(indices []uint64, seed [32]byte) (uin i := uint64(0) total := uint64(len(indices)) buf := make([]byte, 8) + preInputs := b.ComputeShuffledIndexPreInputs(seed) for { - shuffled, err := b.ComputeShuffledIndex(i%total, total, seed) + shuffled, err := b.ComputeShuffledIndex(i%total, total, seed, preInputs, utils.Keccak256) if err != nil { return 0, err } @@ -348,21 +368,33 @@ func (b *BeaconState) GetAttestationParticipationFlagIndicies(data *cltypes.Atte } func (b *BeaconState) GetBeaconCommitee(slot, committeeIndex uint64) ([]uint64, error) { + var cacheKey [16]byte + binary.BigEndian.PutUint64(cacheKey[:], slot) + binary.BigEndian.PutUint64(cacheKey[8:], committeeIndex) + if cachedCommittee, ok := b.committeeCache.Get(cacheKey); ok { + return cachedCommittee.([]uint64), nil + } epoch := b.GetEpochAtSlot(slot) committeesPerSlot := b.CommitteeCount(epoch) - return b.ComputeCommittee( + seed := b.GetSeed(epoch, b.beaconConfig.DomainBeaconAttester) + preInputs := b.ComputeShuffledIndexPreInputs(seed) + hashFunc := utils.OptimizedKeccak256() + committee, err := b.ComputeCommittee( b.GetActiveValidatorsIndices(epoch), - b.GetSeed(epoch, b.beaconConfig.DomainBeaconAttester), + seed, (slot%b.beaconConfig.SlotsPerEpoch)*committeesPerSlot+committeeIndex, committeesPerSlot*b.beaconConfig.SlotsPerEpoch, + preInputs, + hashFunc, ) -} - -func (b *BeaconState) GetIndexedAttestation(attestation *cltypes.Attestation) (*cltypes.IndexedAttestation, error) { - attestingIndicies, err := b.GetAttestingIndicies(attestation.Data, attestation.AggregationBits) if err != nil { return nil, err } + b.committeeCache.Add(cacheKey, committee) + return committee, nil +} + +func (b *BeaconState) GetIndexedAttestation(attestation *cltypes.Attestation, attestingIndicies []uint64) (*cltypes.IndexedAttestation, error) { // Sort the the attestation indicies. sort.Slice(attestingIndicies, func(i, j int) bool { return attestingIndicies[i] < attestingIndicies[j] diff --git a/cmd/erigon-cl/core/state/accessors_test.go b/cmd/erigon-cl/core/state/accessors_test.go index 8a497400f4f..4437f2de00e 100644 --- a/cmd/erigon-cl/core/state/accessors_test.go +++ b/cmd/erigon-cl/core/state/accessors_test.go @@ -41,13 +41,13 @@ func TestActiveValidatorIndices(t *testing.T) { ActivationEpoch: 3, ExitEpoch: 9, EffectiveBalance: 2e9, - }) + }, 2e9) // Active Validator testState.AddValidator(&cltypes.Validator{ ActivationEpoch: 1, ExitEpoch: 9, EffectiveBalance: 2e9, - }) + }, 2e9) testState.SetSlot(epoch * 32) // Epoch testFlags := cltypes.ParticipationFlagsListFromBytes([]byte{1, 1}) testState.SetCurrentEpochParticipation(testFlags) @@ -148,7 +148,8 @@ func TestComputeShuffledIndex(t *testing.T) { t.Run(tc.description, func(t *testing.T) { for i, val := range tc.startInds { state := state.New(&clparams.MainnetBeaconConfig) - got, err := state.ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed) + preInputs := state.ComputeShuffledIndexPreInputs(tc.seed) + got, err := state.ComputeShuffledIndex(val, uint64(len(tc.startInds)), tc.seed, preInputs, utils.Keccak256) // Non-failure case. if err != nil { t.Errorf("unexpected error: %v", err) @@ -164,7 +165,7 @@ func TestComputeShuffledIndex(t *testing.T) { func generateBeaconStateWithValidators(n int) *state.BeaconState { b := state.GetEmptyBeaconState() for i := 0; i < n; i++ { - b.AddValidator(&cltypes.Validator{EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}) + b.AddValidator(&cltypes.Validator{EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}, clparams.MainnetBeaconConfig.MaxEffectiveBalance) } return b } @@ -237,7 +238,7 @@ func TestComputeProposerIndex(t *testing.T) { func TestSyncReward(t *testing.T) { s := state.GetEmptyBeaconState() - s.AddValidator(&cltypes.Validator{EffectiveBalance: 3099999999909, ExitEpoch: 2}) + s.AddValidator(&cltypes.Validator{EffectiveBalance: 3099999999909, ExitEpoch: 2}, 3099999999909) propReward, partRew, err := s.SyncRewards() require.NoError(t, err) require.Equal(t, propReward, uint64(30)) @@ -265,12 +266,13 @@ func TestComputeCommittee(t *testing.T) { epoch := state.Epoch() indices := state.GetActiveValidatorsIndices(epoch) seed := state.GetSeed(epoch, clparams.MainnetBeaconConfig.DomainBeaconAttester) - committees, err := state.ComputeCommittee(indices, seed, 0, 1) + preInputs := state.ComputeShuffledIndexPreInputs(seed) + committees, err := state.ComputeCommittee(indices, seed, 0, 1, preInputs, utils.Keccak256) require.NoError(t, err, "Could not compute committee") // Test shuffled indices are correct for index 5 committee index := uint64(5) - committee5, err := state.ComputeCommittee(indices, seed, index, committeeCount) + committee5, err := state.ComputeCommittee(indices, seed, index, committeeCount, preInputs, utils.Keccak256) require.NoError(t, err, "Could not compute committee") start := (validatorCount * index) / committeeCount end := (validatorCount * (index + 1)) / committeeCount diff --git a/cmd/erigon-cl/core/state/mutators_test.go b/cmd/erigon-cl/core/state/mutators_test.go index 81fab218819..add6317ed6f 100644 --- a/cmd/erigon-cl/core/state/mutators_test.go +++ b/cmd/erigon-cl/core/state/mutators_test.go @@ -15,12 +15,10 @@ const ( func getTestStateBalances(t *testing.T) *state.BeaconState { numVals := uint64(2048) - balances := make([]uint64, numVals) + b := state.GetEmptyBeaconState() for i := uint64(0); i < numVals; i++ { - balances[i] = i + b.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch}, i) } - b := state.GetEmptyBeaconState() - b.SetBalances(balances) return b } @@ -45,9 +43,7 @@ func TestIncreaseBalance(t *testing.T) { beforeBalance := state.Balances()[testInd] state.IncreaseBalance(int(testInd), amount) afterBalance := state.Balances()[testInd] - if afterBalance != beforeBalance+amount { - t.Errorf("unepected after balance: %d, before balance: %d, increase: %d", afterBalance, beforeBalance, amount) - } + require.Equal(t, afterBalance, beforeBalance+amount) } func TestDecreaseBalance(t *testing.T) { @@ -82,9 +78,7 @@ func TestDecreaseBalance(t *testing.T) { state := getTestStateBalances(t) require.NoError(t, state.DecreaseBalance(testInd, tc.delta)) afterBalance := state.Balances()[testInd] - if afterBalance != tc.expectedBalance { - t.Errorf("unexpected resulting balance: got %d, want %d", afterBalance, tc.expectedBalance) - } + require.Equal(t, afterBalance, tc.expectedBalance) }) } } diff --git a/cmd/erigon-cl/core/state/setters.go b/cmd/erigon-cl/core/state/setters.go index 2b771eb866c..d2f23a4da84 100644 --- a/cmd/erigon-cl/core/state/setters.go +++ b/cmd/erigon-cl/core/state/setters.go @@ -24,6 +24,8 @@ func (b *BeaconState) SetGenesisValidatorsRoot(genesisValidatorRoot libcommon.Ha func (b *BeaconState) SetSlot(slot uint64) { b.touchedLeaves[SlotLeafIndex] = true b.slot = slot + // If there is a new slot update the active balance cache. + b._refreshActiveBalances() } func (b *BeaconState) SetFork(fork *cltypes.Fork) { @@ -61,6 +63,9 @@ func (b *BeaconState) SetValidatorAt(index int, validator *cltypes.Validator) er return InvalidValidatorIndex } b.validators[index] = validator + // change in validator set means cache purging + b.activeValidatorsCache.Purge() + b._refreshActiveBalances() return nil } @@ -91,26 +96,30 @@ func (b *BeaconState) SetValidators(validators []*cltypes.Validator) { b.initBeaconState() } -func (b *BeaconState) AddValidator(validator *cltypes.Validator) { +func (b *BeaconState) AddValidator(validator *cltypes.Validator, balance uint64) { b.touchedLeaves[ValidatorsLeafIndex] = true b.validators = append(b.validators, validator) + b.balances = append(b.balances, balance) + if validator.Active(b.Epoch()) { + b.totalActiveBalanceCache += validator.EffectiveBalance + } b.publicKeyIndicies[validator.PublicKey] = uint64(len(b.validators)) - 1 + // change in validator set means cache purging + b.activeValidatorsCache.Purge() + } func (b *BeaconState) SetBalances(balances []uint64) { b.touchedLeaves[BalancesLeafIndex] = true b.balances = balances -} - -func (b *BeaconState) AddBalance(balance uint64) { - b.touchedLeaves[BalancesLeafIndex] = true - b.balances = append(b.balances, balance) + b._refreshActiveBalances() } func (b *BeaconState) SetValidatorBalance(index int, balance uint64) error { if index >= len(b.balances) { return InvalidValidatorIndex } + b.touchedLeaves[BalancesLeafIndex] = true b.balances[index] = balance return nil diff --git a/cmd/erigon-cl/core/state/state.go b/cmd/erigon-cl/core/state/state.go index a77807d6931..63425f75342 100644 --- a/cmd/erigon-cl/core/state/state.go +++ b/cmd/erigon-cl/core/state/state.go @@ -1,6 +1,7 @@ package state import ( + lru "github.com/hashicorp/golang-lru" libcommon "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon/cl/clparams" @@ -55,6 +56,11 @@ type BeaconState struct { leaves [32][32]byte // Pre-computed leaves. touchedLeaves map[StateLeafIndex]bool // Maps each leaf to whether they were touched or not. publicKeyIndicies map[[48]byte]uint64 + // Caches + activeValidatorsCache *lru.Cache + committeeCache *lru.Cache + shuffledSetsCache *lru.Cache + totalActiveBalanceCache uint64 // Configs beaconConfig *clparams.BeaconChainConfig } @@ -95,10 +101,37 @@ func (b *BeaconState) BlockRoot() ([32]byte, error) { }).HashSSZ() } +func (b *BeaconState) _refreshActiveBalances() { + epoch := b.Epoch() + b.totalActiveBalanceCache = 0 + for _, validator := range b.validators { + if validator.Active(epoch) { + b.totalActiveBalanceCache += validator.EffectiveBalance + } + } +} + func (b *BeaconState) initBeaconState() { - b.touchedLeaves = make(map[StateLeafIndex]bool) + if b.touchedLeaves == nil { + b.touchedLeaves = make(map[StateLeafIndex]bool) + } b.publicKeyIndicies = make(map[[48]byte]uint64) + b._refreshActiveBalances() for i, validator := range b.validators { b.publicKeyIndicies[validator.PublicKey] = uint64(i) } + // 5 Epochs at a time is reasonable. + var err error + b.activeValidatorsCache, err = lru.New(5) + if err != nil { + panic(err) + } + b.shuffledSetsCache, err = lru.New(25) + if err != nil { + panic(err) + } + b.committeeCache, err = lru.New(256) + if err != nil { + panic(err) + } } diff --git a/cmd/erigon-cl/core/transition/block_transition.go b/cmd/erigon-cl/core/transition/block_transition.go index e622869e306..74e1e8111aa 100644 --- a/cmd/erigon-cl/core/transition/block_transition.go +++ b/cmd/erigon-cl/core/transition/block_transition.go @@ -21,7 +21,7 @@ func (s *StateTransistor) processBlock(signedBlock *cltypes.SignedBeaconBlock) e // Set execution header accordingly to state. s.state.SetLatestExecutionPayloadHeader(block.Body.ExecutionPayload.Header) } - if err := s.ProcessRandao(block.Body.RandaoReveal); err != nil { + if err := s.ProcessRandao(block.Body.RandaoReveal, block.ProposerIndex); err != nil { return fmt.Errorf("ProcessRandao: %s", err) } if err := s.ProcessEth1Data(block.Body.Eth1Data); err != nil { @@ -57,10 +57,8 @@ func (s *StateTransistor) processOperations(blockBody *cltypes.BeaconBody) error } } // Process each attestations - for _, att := range blockBody.Attestations { - if err := s.ProcessAttestation(att); err != nil { - return fmt.Errorf("ProcessAttestation: %s", err) - } + if err := s.ProcessAttestations(blockBody.Attestations); err != nil { + return fmt.Errorf("ProcessAttestation: %s", err) } // Process each deposit for _, dep := range blockBody.Deposits { diff --git a/cmd/erigon-cl/core/transition/finalization_and_justification.go b/cmd/erigon-cl/core/transition/finalization_and_justification.go index 36b6f0bd27c..d91d523e109 100644 --- a/cmd/erigon-cl/core/transition/finalization_and_justification.go +++ b/cmd/erigon-cl/core/transition/finalization_and_justification.go @@ -1,6 +1,8 @@ package transition import ( + "fmt" + "github.com/ledgerwatch/erigon/cl/clparams" "github.com/ledgerwatch/erigon/cl/cltypes" ) @@ -89,6 +91,7 @@ func (s *StateTransistor) processJustificationBitsAndFinalityAltair() error { if err != nil { return err } + fmt.Println(totalActiveBalance > 1000000000) previousTargetBalance, err := s.state.GetTotalBalance(previousIndices) if err != nil { return err diff --git a/cmd/erigon-cl/core/transition/finalization_and_justification_test.go b/cmd/erigon-cl/core/transition/finalization_and_justification_test.go index 3e5927cd893..ab214998ba4 100644 --- a/cmd/erigon-cl/core/transition/finalization_and_justification_test.go +++ b/cmd/erigon-cl/core/transition/finalization_and_justification_test.go @@ -25,10 +25,10 @@ func getJustificationAndFinalizationState() *state.BeaconState { bits := cltypes.JustificationBits{} bits.FromByte(0x3) state.SetJustificationBits(bits) - state.SetValidators([]*cltypes.Validator{ - {ExitEpoch: epoch}, {ExitEpoch: epoch}, {ExitEpoch: epoch}, {ExitEpoch: epoch}, - }) - state.SetBalances([]uint64{bal, bal, bal, bal}) + state.AddValidator(&cltypes.Validator{ExitEpoch: epoch}, bal) + state.AddValidator(&cltypes.Validator{ExitEpoch: epoch}, bal) + state.AddValidator(&cltypes.Validator{ExitEpoch: epoch}, bal) + state.AddValidator(&cltypes.Validator{ExitEpoch: epoch}, bal) state.SetCurrentEpochParticipation(cltypes.ParticipationFlagsList{0b01, 0b01, 0b01, 0b01}) state.SetPreviousEpochParticipation(cltypes.ParticipationFlagsList{0b01, 0b01, 0b01, 0b01}) diff --git a/cmd/erigon-cl/core/transition/operations.go b/cmd/erigon-cl/core/transition/operations.go index 2470f0de3de..f300a57f2d7 100644 --- a/cmd/erigon-cl/core/transition/operations.go +++ b/cmd/erigon-cl/core/transition/operations.go @@ -10,7 +10,6 @@ import ( "github.com/ledgerwatch/erigon/cl/fork" "github.com/ledgerwatch/erigon/cl/utils" "github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state" - "golang.org/x/exp/slices" ) func IsSlashableValidator(validator *cltypes.Validator, epoch uint64) bool { @@ -31,15 +30,6 @@ func IsSlashableAttestationData(d1, d2 *cltypes.AttestationData) (bool, error) { return (hash1 != hash2 && d1.Target.Epoch == d2.Target.Epoch) || (d1.Source.Epoch < d2.Source.Epoch && d2.Target.Epoch < d1.Target.Epoch), nil } -func IsSortedSet(vals []uint64) bool { - for i := 0; i < len(vals)-1; i++ { - if vals[i] >= vals[i+1] { - return false - } - } - return true -} - func GetSetIntersection(v1, v2 []uint64) []uint64 { intersection := []uint64{} present := map[uint64]bool{} @@ -58,7 +48,7 @@ func GetSetIntersection(v1, v2 []uint64) []uint64 { func isValidIndexedAttestation(state *state.BeaconState, att *cltypes.IndexedAttestation) (bool, error) { inds := att.AttestingIndices - if len(inds) == 0 || !IsSortedSet(inds) { + if len(inds) == 0 || !utils.IsSliceSortedSet(inds) { return false, fmt.Errorf("isValidIndexedAttestation: attesting indices are not sorted or are null") } @@ -241,8 +231,7 @@ func (s *StateTransistor) ProcessDeposit(deposit *cltypes.Deposit) error { } if valid { // Append validator - s.state.AddValidator(s.state.ValidatorFromDeposit(deposit)) - s.state.AddBalance(amount) + s.state.AddValidator(s.state.ValidatorFromDeposit(deposit), amount) // Altair only s.state.AddCurrentEpochParticipationFlags(cltypes.ParticipationFlags(0)) s.state.AddPreviousEpochParticipationFlags(cltypes.ParticipationFlags(0)) @@ -298,92 +287,3 @@ func (s *StateTransistor) ProcessVoluntaryExit(signedVoluntaryExit *cltypes.Sign // Do the exit (same process in slashing). return s.state.InitiateValidatorExit(voluntaryExit.ValidatorIndex) } - -// ProcessVoluntaryExit takes a voluntary exit and applies state transition. -func (s *StateTransistor) ProcessAttestation(attestation *cltypes.Attestation) error { - participationFlagWeights := []uint64{ - s.beaconConfig.TimelySourceWeight, - s.beaconConfig.TimelyTargetWeight, - s.beaconConfig.TimelyHeadWeight, - } - - totalActiveBalance, err := s.state.GetTotalActiveBalance() - if err != nil { - return err - } - data := attestation.Data - currentEpoch := s.state.Epoch() - previousEpoch := s.state.PreviousEpoch() - stateSlot := s.state.Slot() - if (data.Target.Epoch != currentEpoch && data.Target.Epoch != previousEpoch) || data.Target.Epoch != s.state.GetEpochAtSlot(data.Slot) { - return errors.New("ProcessAttestation: attestation with invalid epoch") - } - if data.Slot+s.beaconConfig.MinAttestationInclusionDelay > stateSlot || stateSlot > data.Slot+s.beaconConfig.SlotsPerEpoch { - return errors.New("ProcessAttestation: attestation slot not in range") - } - if data.Index >= s.state.CommitteeCount(data.Target.Epoch) { - return errors.New("ProcessAttestation: attester index out of range") - } - participationFlagsIndicies, err := s.state.GetAttestationParticipationFlagIndicies(attestation.Data, stateSlot-data.Slot) - if err != nil { - return err - } - valid, err := s.verifyAttestation(attestation) - if err != nil { - return err - } - if !valid { - return errors.New("ProcessAttestation: wrong bls data") - } - var epochParticipation cltypes.ParticipationFlagsList - if data.Target.Epoch == currentEpoch { - epochParticipation = s.state.CurrentEpochParticipation() - } else { - epochParticipation = s.state.PreviousEpochParticipation() - } - - var proposerRewardNumerator uint64 - attestingIndicies, err := s.state.GetAttestingIndicies(attestation.Data, attestation.AggregationBits) - if err != nil { - return err - } - - for _, attesterIndex := range attestingIndicies { - for flagIndex, weight := range participationFlagWeights { - if !slices.Contains(participationFlagsIndicies, uint8(flagIndex)) || epochParticipation[attesterIndex].HasFlag(flagIndex) { - continue - } - epochParticipation[attesterIndex] = epochParticipation[attesterIndex].Add(flagIndex) - baseReward, err := s.state.BaseReward(totalActiveBalance, attesterIndex) - if err != nil { - return err - } - proposerRewardNumerator += baseReward * weight - } - } - // Reward proposer - proposer, err := s.state.GetBeaconProposerIndex() - if err != nil { - return err - } - // Set participation - if data.Target.Epoch == currentEpoch { - s.state.SetCurrentEpochParticipation(epochParticipation) - } else { - s.state.SetPreviousEpochParticipation(epochParticipation) - } - proposerRewardDenominator := (s.beaconConfig.WeightDenominator - s.beaconConfig.ProposerWeight) * s.beaconConfig.WeightDenominator / s.beaconConfig.ProposerWeight - reward := proposerRewardNumerator / proposerRewardDenominator - return s.state.IncreaseBalance(int(proposer), reward) -} - -func (s *StateTransistor) verifyAttestation(attestation *cltypes.Attestation) (bool, error) { - if s.noValidate { - return true, nil - } - indexedAttestation, err := s.state.GetIndexedAttestation(attestation) - if err != nil { - return false, err - } - return isValidIndexedAttestation(s.state, indexedAttestation) -} diff --git a/cmd/erigon-cl/core/transition/operations_test.go b/cmd/erigon-cl/core/transition/operations_test.go index 4c09475c720..b409e3c36b9 100644 --- a/cmd/erigon-cl/core/transition/operations_test.go +++ b/cmd/erigon-cl/core/transition/operations_test.go @@ -325,32 +325,14 @@ func TestProcessDeposit(t *testing.T) { }, } testState := state.GetEmptyBeaconState() - testState.SetBalances([]uint64{0}) testState.AddValidator(&cltypes.Validator{ PublicKey: [48]byte{1}, WithdrawalCredentials: [32]byte{1, 2, 3}, - }) + }, 0) testState.SetEth1Data(eth1Data) s := New(testState, &clparams.MainnetBeaconConfig, nil, true) require.NoError(t, s.ProcessDeposit(deposit)) - if testState.Balances()[1] != deposit.Data.Amount { - t.Errorf( - "Expected state validator balances index 0 to equal %d, received %d", - deposit.Data.Amount, - testState.Balances()[1], - ) - } - /* - beaconState, err := state_native.InitializeFromProtoAltair(ðpb.BeaconStateAltair{ - Validators: registry, - Balances: balances, - Eth1Data: eth1Data, - Fork: ðpb.Fork{ - PreviousVersion: params.BeaconConfig().GenesisForkVersion, - CurrentVersion: params.BeaconConfig().GenesisForkVersion, - }, - })*/ - //s := New() + require.Equal(t, deposit.Data.Amount, testState.Balances()[1]) } func TestProcessVoluntaryExits(t *testing.T) { @@ -364,7 +346,7 @@ func TestProcessVoluntaryExits(t *testing.T) { state.AddValidator(&cltypes.Validator{ ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, ActivationEpoch: 0, - }) + }, 0) state.SetSlot((clparams.MainnetBeaconConfig.SlotsPerEpoch * 5) + (clparams.MainnetBeaconConfig.SlotsPerEpoch * clparams.MainnetBeaconConfig.ShardCommitteePeriod)) transitioner := New(state, &clparams.MainnetBeaconConfig, nil, true) @@ -381,9 +363,8 @@ func TestProcessAttestation(t *testing.T) { EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance, ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, WithdrawableEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, - }) + }, clparams.MainnetBeaconConfig.MaxEffectiveBalance) beaconState.AddCurrentEpochParticipationFlags(cltypes.ParticipationFlags(0)) - beaconState.AddBalance(clparams.MainnetBeaconConfig.MaxEffectiveBalance) } aggBits := []byte{7} @@ -399,7 +380,7 @@ func TestProcessAttestation(t *testing.T) { } s := New(beaconState, &clparams.MainnetBeaconConfig, nil, true) - require.NoError(t, s.ProcessAttestation(att)) + require.NoError(t, s.ProcessAttestations([]*cltypes.Attestation{att})) p := beaconState.CurrentEpochParticipation() require.NoError(t, err) diff --git a/cmd/erigon-cl/core/transition/process_attestations.go b/cmd/erigon-cl/core/transition/process_attestations.go new file mode 100644 index 00000000000..ad6e657e60e --- /dev/null +++ b/cmd/erigon-cl/core/transition/process_attestations.go @@ -0,0 +1,136 @@ +package transition + +import ( + "errors" + + "github.com/ledgerwatch/erigon/cl/cltypes" + "github.com/ledgerwatch/erigon/cmd/erigon-cl/core/state" + "golang.org/x/exp/slices" +) + +func (s *StateTransistor) ProcessAttestations(attestations []*cltypes.Attestation) error { + var err error + attestingIndiciesSet := make([][]uint64, len(attestations)) + for i, attestation := range attestations { + if attestingIndiciesSet[i], err = s.processAttestation(attestation); err != nil { + return err + } + } + valid, err := s.verifyAttestations(attestations, attestingIndiciesSet) + if err != nil { + return err + } + if !valid { + return errors.New("ProcessAttestation: wrong bls data") + } + return nil +} + +// ProcessAttestation takes an attestation and process it. +func (s *StateTransistor) processAttestation(attestation *cltypes.Attestation) ([]uint64, error) { + participationFlagWeights := []uint64{ + s.beaconConfig.TimelySourceWeight, + s.beaconConfig.TimelyTargetWeight, + s.beaconConfig.TimelyHeadWeight, + } + + totalActiveBalance, err := s.state.GetTotalActiveBalance() + if err != nil { + return nil, err + } + data := attestation.Data + currentEpoch := s.state.Epoch() + previousEpoch := s.state.PreviousEpoch() + stateSlot := s.state.Slot() + if (data.Target.Epoch != currentEpoch && data.Target.Epoch != previousEpoch) || data.Target.Epoch != s.state.GetEpochAtSlot(data.Slot) { + return nil, errors.New("ProcessAttestation: attestation with invalid epoch") + } + if data.Slot+s.beaconConfig.MinAttestationInclusionDelay > stateSlot || stateSlot > data.Slot+s.beaconConfig.SlotsPerEpoch { + return nil, errors.New("ProcessAttestation: attestation slot not in range") + } + if data.Index >= s.state.CommitteeCount(data.Target.Epoch) { + return nil, errors.New("ProcessAttestation: attester index out of range") + } + participationFlagsIndicies, err := s.state.GetAttestationParticipationFlagIndicies(attestation.Data, stateSlot-data.Slot) + if err != nil { + return nil, err + } + + attestingIndicies, err := s.state.GetAttestingIndicies(attestation.Data, attestation.AggregationBits) + if err != nil { + return nil, err + } + var proposerRewardNumerator uint64 + + var epochParticipation cltypes.ParticipationFlagsList + if data.Target.Epoch == currentEpoch { + epochParticipation = s.state.CurrentEpochParticipation() + } else { + epochParticipation = s.state.PreviousEpochParticipation() + } + + for _, attesterIndex := range attestingIndicies { + for flagIndex, weight := range participationFlagWeights { + if !slices.Contains(participationFlagsIndicies, uint8(flagIndex)) || epochParticipation[attesterIndex].HasFlag(flagIndex) { + continue + } + epochParticipation[attesterIndex] = epochParticipation[attesterIndex].Add(flagIndex) + baseReward, err := s.state.BaseReward(totalActiveBalance, attesterIndex) + if err != nil { + return nil, err + } + proposerRewardNumerator += baseReward * weight + } + } + // Reward proposer + proposer, err := s.state.GetBeaconProposerIndex() + if err != nil { + return nil, err + } + // Set participation + if data.Target.Epoch == currentEpoch { + s.state.SetCurrentEpochParticipation(epochParticipation) + } else { + s.state.SetPreviousEpochParticipation(epochParticipation) + } + proposerRewardDenominator := (s.beaconConfig.WeightDenominator - s.beaconConfig.ProposerWeight) * s.beaconConfig.WeightDenominator / s.beaconConfig.ProposerWeight + reward := proposerRewardNumerator / proposerRewardDenominator + return attestingIndicies, s.state.IncreaseBalance(int(proposer), reward) +} + +type verifyAttestationWorkersResult struct { + success bool + err error +} + +func verifyAttestationWorker(state *state.BeaconState, attestation *cltypes.Attestation, attestingIndicies []uint64, resultCh chan verifyAttestationWorkersResult) { + indexedAttestation, err := state.GetIndexedAttestation(attestation, attestingIndicies) + if err != nil { + resultCh <- verifyAttestationWorkersResult{err: err} + return + } + success, err := isValidIndexedAttestation(state, indexedAttestation) + resultCh <- verifyAttestationWorkersResult{success: success, err: err} +} + +func (s *StateTransistor) verifyAttestations(attestations []*cltypes.Attestation, attestingIndicies [][]uint64) (bool, error) { + if s.noValidate { + return true, nil + } + resultCh := make(chan verifyAttestationWorkersResult, len(attestations)) + + for i, attestation := range attestations { + go verifyAttestationWorker(s.state, attestation, attestingIndicies[i], resultCh) + } + for i := 0; i < len(attestations); i++ { + result := <-resultCh + if result.err != nil { + return false, result.err + } + if !result.success { + return false, nil + } + } + close(resultCh) + return true, nil +} diff --git a/cmd/erigon-cl/core/transition/process_slashings_test.go b/cmd/erigon-cl/core/transition/process_slashings_test.go index 3a9ba2cda8e..aaff9aa0fe0 100644 --- a/cmd/erigon-cl/core/transition/process_slashings_test.go +++ b/cmd/erigon-cl/core/transition/process_slashings_test.go @@ -16,8 +16,7 @@ func TestProcessSlashingsNoSlash(t *testing.T) { base := state.GetEmptyBeaconStateWithVersion(clparams.AltairVersion) base.AddValidator(&cltypes.Validator{ Slashed: true, - }) - base.AddBalance(clparams.MainnetBeaconConfig.MaxEffectiveBalance) + }, clparams.MainnetBeaconConfig.MaxEffectiveBalance) base.SetSlashingSegmentAt(0, 0) base.SetSlashingSegmentAt(1, 1e9) s := transition.New(base, &clparams.MainnetBeaconConfig, nil, false) @@ -30,8 +29,8 @@ func getTestStateSlashings1() *state.BeaconState { state := state.GetEmptyBeaconStateWithVersion(clparams.AltairVersion) state.AddValidator(&cltypes.Validator{Slashed: true, WithdrawableEpoch: clparams.MainnetBeaconConfig.EpochsPerSlashingsVector / 2, - EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}) - state.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}) + EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}, clparams.MainnetBeaconConfig.MaxEffectiveBalance) + state.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}, clparams.MainnetBeaconConfig.MaxEffectiveBalance) state.SetBalances([]uint64{clparams.MainnetBeaconConfig.MaxEffectiveBalance, clparams.MainnetBeaconConfig.MaxEffectiveBalance}) state.SetSlashingSegmentAt(0, 0) state.SetSlashingSegmentAt(1, 1e9) @@ -40,7 +39,7 @@ func getTestStateSlashings1() *state.BeaconState { func getTestStateSlashings2() *state.BeaconState { state := getTestStateSlashings1() - state.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}) + state.AddValidator(&cltypes.Validator{ExitEpoch: clparams.MainnetBeaconConfig.FarFutureEpoch, EffectiveBalance: clparams.MainnetBeaconConfig.MaxEffectiveBalance}, clparams.MainnetBeaconConfig.MaxEffectiveBalance) return state } diff --git a/cmd/erigon-cl/core/transition/process_slots.go b/cmd/erigon-cl/core/transition/process_slots.go index a92bf249695..6cfffe386af 100644 --- a/cmd/erigon-cl/core/transition/process_slots.go +++ b/cmd/erigon-cl/core/transition/process_slots.go @@ -25,7 +25,7 @@ func (s *StateTransistor) TransitionState(block *cltypes.SignedBeaconBlock) erro if err := s.processBlock(block); err != nil { return err } - // TODO add logic to process block and update state. + if !s.noValidate { expectedStateRoot, err := s.state.HashSSZ() if err != nil { diff --git a/cmd/erigon-cl/core/transition/processing.go b/cmd/erigon-cl/core/transition/processing.go index 851ed0908e0..3b0b6c4179e 100644 --- a/cmd/erigon-cl/core/transition/processing.go +++ b/cmd/erigon-cl/core/transition/processing.go @@ -18,26 +18,29 @@ func computeSigningRootEpoch(epoch uint64, domain []byte) (libcommon.Hash, error } func (s *StateTransistor) ProcessBlockHeader(block *cltypes.BeaconBlock) error { - if block.Slot != s.state.Slot() { - return fmt.Errorf("state slot: %d, not equal to block slot: %d", s.state.Slot(), block.Slot) - } - if block.Slot <= s.state.LatestBlockHeader().Slot { - return fmt.Errorf("slock slot: %d, not greater than latest block slot: %d", block.Slot, s.state.LatestBlockHeader().Slot) - } - propInd, err := s.state.GetBeaconProposerIndex() - if err != nil { - return fmt.Errorf("error in GetBeaconProposerIndex: %v", err) - } - if block.ProposerIndex != propInd { - return fmt.Errorf("block proposer index: %d, does not match beacon proposer index: %d", block.ProposerIndex, propInd) - } - latestRoot, err := s.state.LatestBlockHeader().HashSSZ() - if err != nil { - return fmt.Errorf("unable to hash tree root of latest block header: %v", err) - } - if block.ParentRoot != latestRoot { - return fmt.Errorf("block parent root: %x, does not match latest block root: %x", block.ParentRoot, latestRoot) + if !s.noValidate { + if block.Slot != s.state.Slot() { + return fmt.Errorf("state slot: %d, not equal to block slot: %d", s.state.Slot(), block.Slot) + } + if block.Slot <= s.state.LatestBlockHeader().Slot { + return fmt.Errorf("slock slot: %d, not greater than latest block slot: %d", block.Slot, s.state.LatestBlockHeader().Slot) + } + propInd, err := s.state.GetBeaconProposerIndex() + if err != nil { + return fmt.Errorf("error in GetBeaconProposerIndex: %v", err) + } + if block.ProposerIndex != propInd { + return fmt.Errorf("block proposer index: %d, does not match beacon proposer index: %d", block.ProposerIndex, propInd) + } + latestRoot, err := s.state.LatestBlockHeader().HashSSZ() + if err != nil { + return fmt.Errorf("unable to hash tree root of latest block header: %v", err) + } + if block.ParentRoot != latestRoot { + return fmt.Errorf("block parent root: %x, does not match latest block root: %x", block.ParentRoot, latestRoot) + } } + bodyRoot, err := block.Body.HashSSZ() if err != nil { return fmt.Errorf("unable to hash tree root of block body: %v", err) @@ -59,13 +62,9 @@ func (s *StateTransistor) ProcessBlockHeader(block *cltypes.BeaconBlock) error { return nil } -func (s *StateTransistor) ProcessRandao(randao [96]byte) error { +func (s *StateTransistor) ProcessRandao(randao [96]byte, proposerIndex uint64) error { epoch := s.state.Epoch() - propInd, err := s.state.GetBeaconProposerIndex() - if err != nil { - return fmt.Errorf("unable to get proposer index: %v", err) - } - proposer, err := s.state.ValidatorAt(int(propInd)) + proposer, err := s.state.ValidatorAt(int(proposerIndex)) if err != nil { return err } @@ -102,25 +101,10 @@ func (s *StateTransistor) ProcessEth1Data(eth1Data *cltypes.Eth1Data) error { s.state.AddEth1DataVote(eth1Data) newVotes := s.state.Eth1DataVotes() - ethDataHash, err := eth1Data.HashSSZ() - if err != nil { - return fmt.Errorf("unable to get hash tree root of eth1data: %v", err) - } - // Count how many times body.Eth1Data appears in the votes by comparing their hashes. + // Count how many times body.Eth1Data appears in the votes. numVotes := 0 for i := 0; i < len(newVotes); i++ { - candidateHash, err := newVotes[i].HashSSZ() - if err != nil { - return fmt.Errorf("unable to get hash tree root of eth1data: %v", err) - } - // Check if hash bytes are equal. - match := true - for i := 0; i < len(candidateHash); i++ { - if candidateHash[i] != ethDataHash[i] { - match = false - } - } - if match { + if eth1Data.Equal(newVotes[i]) { numVotes += 1 } } diff --git a/cmd/erigon-cl/core/transition/processing_test.go b/cmd/erigon-cl/core/transition/processing_test.go index 21697dc7bdb..d4644d504bc 100644 --- a/cmd/erigon-cl/core/transition/processing_test.go +++ b/cmd/erigon-cl/core/transition/processing_test.go @@ -211,7 +211,7 @@ func TestProcessRandao(t *testing.T) { for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { s := New(tc.state, &clparams.MainnetBeaconConfig, nil, true) - err := s.ProcessRandao(tc.body.RandaoReveal) + err := s.ProcessRandao(tc.body.RandaoReveal, propInd) if tc.wantErr { if err == nil { t.Errorf("unexpected success, wanted error") diff --git a/cmd/erigon-cl/stages/stages_beacon_state.go b/cmd/erigon-cl/stages/stages_beacon_state.go index d184dc4edab..330f906ad8b 100644 --- a/cmd/erigon-cl/stages/stages_beacon_state.go +++ b/cmd/erigon-cl/stages/stages_beacon_state.go @@ -54,7 +54,7 @@ func SpawnStageBeaconState(cfg StageBeaconStateCfg, s *stagedsync.StageState, tx defer tx.Rollback() } // Initialize the transistor - stateTransistor := transition.New(cfg.state, cfg.beaconCfg, cfg.genesisCfg, false) + stateTransistor := transition.New(cfg.state, cfg.beaconCfg, cfg.genesisCfg, true) endSlot, err := stages.GetStageProgress(tx, stages.BeaconBlocks) if err != nil { diff --git a/go.mod b/go.mod index 146dbb1ae90..5f5a8c4f5a3 100644 --- a/go.mod +++ b/go.mod @@ -222,6 +222,7 @@ require ( github.com/prometheus/client_model v0.2.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect + github.com/protolambda/eth2-shuffle v1.1.0 // indirect github.com/raulk/go-watchdog v1.3.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect diff --git a/go.sum b/go.sum index 40b1edeb30a..481512f1856 100644 --- a/go.sum +++ b/go.sum @@ -839,6 +839,8 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1 github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.8.0 h1:ODq8ZFEaYeCaZOJlZZdJA2AbQR98dSHSM1KW/You5mo= github.com/prometheus/procfs v0.8.0/go.mod h1:z7EfXMXOkbkqb9IINtpCn86r/to3BnA0uaxHdg830/4= +github.com/protolambda/eth2-shuffle v1.1.0 h1:gixIBI84IeugTwwHXm8vej1bSSEhueBCSryA4lAKRLU= +github.com/protolambda/eth2-shuffle v1.1.0/go.mod h1:FhA2c0tN15LTC+4T9DNVm+55S7uXTTjQ8TQnBuXlkF8= github.com/prysmaticlabs/fastssz v0.0.0-20220628121656-93dfe28febab h1:Y3PcvUrnneMWLuypZpwPz8P70/DQsz6KgV9JveKpyZs= github.com/prysmaticlabs/fastssz v0.0.0-20220628121656-93dfe28febab/go.mod h1:MA5zShstUwCQaE9faGHgCGvEWUbG87p4SAXINhmCkvg= github.com/prysmaticlabs/go-bitfield v0.0.0-20210809151128-385d8c5e3fb7 h1:0tVE4tdWQK9ZpYygoV7+vS6QkDvQVySboMVEIxBJmXw=