Skip to content

chore: Implement attestation message handling #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions modules/network/keeper/fixtures_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package keeper

import (
"context"
"maps"
"slices"
"strings"
"time"

"cosmossdk.io/math"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
cmttypes "github.com/cometbft/cometbft/types"
"github.com/cosmos/cosmos-sdk/crypto/keys/ed25519"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
stakingtypes "github.com/cosmos/cosmos-sdk/x/staking/types"
"github.com/libp2p/go-libp2p/core/crypto"

rollkittypes "github.com/rollkit/rollkit/types"

"github.com/rollkit/go-execution-abci/modules/network/types"
)

func HeaderFixture(signer *ed25519.PrivKey, appHash []byte, mutators ...func(*rollkittypes.SignedHeader)) *rollkittypes.SignedHeader {
header := rollkittypes.Header{
BaseHeader: rollkittypes.BaseHeader{
Height: 10,
Time: uint64(time.Now().UnixNano()),
ChainID: "testing",
},
Version: rollkittypes.Version{Block: 1, App: 1},
ProposerAddress: signer.PubKey().Address(),
AppHash: appHash,
DataHash: []byte("data_hash"),
ConsensusHash: []byte("consensus_hash"),
ValidatorHash: []byte("validator_hash"),
}
signedHeader := &rollkittypes.SignedHeader{
Header: header,
Signature: appHash,
Signer: rollkittypes.Signer{PubKey: must(crypto.UnmarshalEd25519PublicKey(signer.PubKey().Bytes()))},
}
for _, m := range mutators {
m(signedHeader)
}
return signedHeader
}

func VoteFixture(myAppHash []byte, voteSigner *ed25519.PrivKey, mutators ...func(vote *cmtproto.Vote)) *cmtproto.Vote {
const chainID = "testing"

vote := &cmtproto.Vote{
Type: cmtproto.PrecommitType,
Height: 10,
Round: 0,
BlockID: cmtproto.BlockID{Hash: myAppHash, PartSetHeader: cmtproto.PartSetHeader{Total: 1, Hash: myAppHash}},
Timestamp: time.Now().UTC(),
ValidatorAddress: voteSigner.PubKey().Address(),
ValidatorIndex: 0,
}
vote.Signature = must(voteSigner.Sign(cmttypes.VoteSignBytes(chainID, vote)))

for _, m := range mutators {
m(vote)
}
return vote
}

var _ types.StakingKeeper = &MockStakingKeeper{}

type MockStakingKeeper struct {
activeSet map[string]stakingtypes.Validator
}

func NewMockStakingKeeper() MockStakingKeeper {
return MockStakingKeeper{
activeSet: make(map[string]stakingtypes.Validator),
}
}

func (m *MockStakingKeeper) SetValidator(ctx context.Context, validator stakingtypes.Validator) error {
m.activeSet[validator.GetOperator()] = validator
return nil
}
func (m MockStakingKeeper) GetAllValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) {
return slices.SortedFunc(maps.Values(m.activeSet), func(v1 stakingtypes.Validator, v2 stakingtypes.Validator) int {
return strings.Compare(v1.OperatorAddress, v2.OperatorAddress)
}), nil
}
func (m MockStakingKeeper) GetValidator(ctx context.Context, addr sdk.ValAddress) (validator stakingtypes.Validator, err error) {
// First try to find the validator by address
validator, found := m.activeSet[addr.String()]
if found {
return validator, nil
}

//// If not found by address, try to find by public key address
//addrStr := addr.String()
//for valAddrStr, pubKey := range m.pubKeys {
// if pubKey.Address().String() == addrStr {
// validator, found = m.activeSet[valAddrStr]
// if found {
// return validator, nil
// }
// }
//}

return validator, sdkerrors.ErrNotFound
}

func (m MockStakingKeeper) GetLastValidators(ctx context.Context) (validators []stakingtypes.Validator, err error) {
for _, validator := range m.activeSet {
if validator.IsBonded() { // Assuming IsBonded() identifies if a validator is in the last validators
validators = append(validators, validator)
}
}
return
}

func (m MockStakingKeeper) GetLastTotalPower(ctx context.Context) (math.Int, error) {
return math.NewInt(int64(len(m.activeSet))), nil
}
162 changes: 115 additions & 47 deletions modules/network/keeper/msg_server.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package keeper

import (
"bytes"
"context"
"errors"
"fmt"

"cosmossdk.io/collections"
sdkerr "cosmossdk.io/errors"
"cosmossdk.io/math"
cmtproto "github.com/cometbft/cometbft/proto/tendermint/types"
cmttypes "github.com/cometbft/cometbft/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
govtypes "github.com/cosmos/cosmos-sdk/x/gov/types"
"github.com/cosmos/gogoproto/proto"

"github.com/rollkit/go-execution-abci/modules/network/types"
)
Expand All @@ -30,90 +33,155 @@ var _ types.MsgServer = msgServer{}
func (k msgServer) Attest(goCtx context.Context, msg *types.MsgAttest) (*types.MsgAttestResponse, error) {
ctx := sdk.UnwrapSDKContext(goCtx)

if k.GetParams(ctx).SignMode == types.SignMode_SIGN_MODE_CHECKPOINT &&
!k.IsCheckpointHeight(ctx, msg.Height) {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "height %d is not a checkpoint", msg.Height)
}
has, err := k.IsInAttesterSet(ctx, msg.Validator)
if err != nil {
return nil, sdkerr.Wrapf(err, "in attester set")
if err := k.validateAttestation(ctx, msg); err != nil {
return nil, err
}
if !has {
return nil, sdkerr.Wrapf(sdkerrors.ErrUnauthorized, "validator %s not in attester set", msg.Validator)
// can vote only for the last epoch
if delta := ctx.BlockHeight() - msg.Height; delta < 0 || delta > int64(k.GetParams(ctx).EpochLength) {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "exceeded voting window: %d blocks", delta)
}

index, found := k.GetValidatorIndex(ctx, msg.Validator)
valIndexPos, found := k.GetValidatorIndex(ctx, msg.Validator)
if !found {
return nil, sdkerr.Wrapf(sdkerrors.ErrNotFound, "validator index not found for %s", msg.Validator)
}

// todo (Alex): we need to set a limit to not have validators attest old blocks. Also make sure that this relates with
// the retention period for pruning
bitmap, err := k.GetAttestationBitmap(ctx, msg.Height)
if err != nil && !errors.Is(err, collections.ErrNotFound) {
return nil, sdkerr.Wrap(err, "get attestation bitmap")
vote, err := k.verifyVote(ctx, msg)
if err != nil {
return nil, err
}
if bitmap == nil {

if err := k.updateAttestationBitmap(ctx, msg, valIndexPos); err != nil {
return nil, sdkerr.Wrap(err, "update attestation bitmap")
}

if err := k.SetSignature(ctx, msg.Height, msg.Validator, vote.Signature); err != nil {
return nil, sdkerr.Wrap(err, "store signature")
}

if err := k.updateEpochBitmap(ctx, uint64(msg.Height), valIndexPos); err != nil {
return nil, err
}

// Emit event
ctx.EventManager().EmitEvent(
sdk.NewEvent(
types.TypeMsgAttest,
sdk.NewAttribute("validator", msg.Validator),
sdk.NewAttribute("height", math.NewInt(msg.Height).String()),
),
)
return &types.MsgAttestResponse{}, nil
}

func (k msgServer) updateEpochBitmap(ctx sdk.Context, votedEpoch uint64, index uint16) error {
epochBitmap := k.GetEpochBitmap(ctx, votedEpoch)
if epochBitmap == nil {
validators, err := k.stakingKeeper.GetLastValidators(ctx)
if err != nil {
return nil, err
return err
}
numValidators := 0
for _, v := range validators {
if v.IsBonded() {
numValidators++
}
}
bitmap = k.bitmapHelper.NewBitmap(numValidators)
epochBitmap = k.bitmapHelper.NewBitmap(numValidators)
}

if k.bitmapHelper.IsSet(bitmap, int(index)) {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "validator %s already attested for height %d", msg.Validator, msg.Height)
k.bitmapHelper.SetBit(epochBitmap, int(index))
if err := k.SetEpochBitmap(ctx, votedEpoch, epochBitmap); err != nil {
return sdkerr.Wrap(err, "set epoch bitmap")
}
return nil
}

// TODO: Verify the vote signature here once we implement vote parsing
// validateAttestation validates the attestation request
func (k msgServer) validateAttestation(ctx sdk.Context, msg *types.MsgAttest) error {
if k.GetParams(ctx).SignMode == types.SignMode_SIGN_MODE_CHECKPOINT &&
!k.IsCheckpointHeight(ctx, msg.Height) {
return sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "height %d is not a checkpoint", msg.Height)
}

// Set the bit
k.bitmapHelper.SetBit(bitmap, int(index))
if err := k.SetAttestationBitmap(ctx, msg.Height, bitmap); err != nil {
return nil, sdkerr.Wrap(err, "set attestation bitmap")
has, err := k.IsInAttesterSet(ctx, msg.Validator)
if err != nil {
return sdkerr.Wrapf(err, "in attester set")
}
if !has {
return sdkerr.Wrapf(sdkerrors.ErrUnauthorized, "validator %s not in attester set", msg.Validator)
}
return nil
}

// Store signature using the new collection method
if err := k.SetSignature(ctx, msg.Height, msg.Validator, msg.Vote); err != nil {
return nil, sdkerr.Wrap(err, "store signature")
// updateAttestationBitmap handles bitmap operations for attestation
func (k msgServer) updateAttestationBitmap(ctx sdk.Context, msg *types.MsgAttest, index uint16) error {
bitmap, err := k.GetAttestationBitmap(ctx, msg.Height)
if err != nil && !sdkerr.IsOf(err, collections.ErrNotFound) {
return err
}

epoch := k.GetCurrentEpoch(ctx)
epochBitmap := k.GetEpochBitmap(ctx, epoch)
if epochBitmap == nil {
if bitmap == nil {
validators, err := k.stakingKeeper.GetLastValidators(ctx)
if err != nil {
return nil, err
return err
}
numValidators := 0
for _, v := range validators {
if v.IsBonded() {
numValidators++
}
}
epochBitmap = k.bitmapHelper.NewBitmap(numValidators)
bitmap = k.bitmapHelper.NewBitmap(numValidators)
}
k.bitmapHelper.SetBit(epochBitmap, int(index))
if err := k.SetEpochBitmap(ctx, epoch, epochBitmap); err != nil {
return nil, sdkerr.Wrap(err, "set epoch bitmap")

if k.bitmapHelper.IsSet(bitmap, int(index)) {
return sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "validator %s already attested for height %d", msg.Validator, msg.Height)
}

// Emit event
ctx.EventManager().EmitEvent(
sdk.NewEvent(
types.TypeMsgAttest,
sdk.NewAttribute("validator", msg.Validator),
sdk.NewAttribute("height", math.NewInt(msg.Height).String()),
),
)
k.bitmapHelper.SetBit(bitmap, int(index))

return &types.MsgAttestResponse{}, nil
if err := k.SetAttestationBitmap(ctx, msg.Height, bitmap); err != nil {
return sdkerr.Wrap(err, "set attestation bitmap")
}
return nil
}

// verifyVote verifies the vote signature and block hash
func (k msgServer) verifyVote(ctx sdk.Context, msg *types.MsgAttest) (*cmtproto.Vote, error) {
var vote cmtproto.Vote
if err := proto.Unmarshal(msg.Vote, &vote); err != nil {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "unmarshal vote: %s", err)
}
if msg.Height != vote.Height {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "vote height does not match attestation height")
}
if len(vote.Signature) == 0 {
return nil, sdkerrors.ErrInvalidRequest.Wrap("empty signature")
}

// todo (Alex): validate app hash match, vote clock drift

valAddress, err := sdk.ValAddressFromBech32(msg.Validator)
if err != nil {
return nil, sdkerr.Wrap(err, "invalid validator address")
}
validator, err := k.stakingKeeper.GetValidator(ctx, valAddress)
if err != nil {
return nil, sdkerr.Wrapf(err, "get validator")
}
pubKey, err := validator.ConsPubKey()
if err != nil {
return nil, sdkerr.Wrapf(err, "pubkey")
}
if !bytes.Equal(pubKey.Address().Bytes(), vote.ValidatorAddress) {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "pubkey address does not match validator address")
}
voteSignBytes := cmttypes.VoteSignBytes(ctx.ChainID(), &vote)
if !pubKey.VerifySignature(voteSignBytes, vote.Signature) {
return nil, sdkerr.Wrapf(sdkerrors.ErrInvalidRequest, "invalid vote signature")
}

return &vote, nil
}

// JoinAttesterSet handles MsgJoinAttesterSet
Expand Down
Loading
Loading