Skip to content
91 changes: 55 additions & 36 deletions data/account/participationRegistry.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"database/sql"
"errors"
"fmt"
"strings"

"github.com/algorand/go-deadlock"

Expand All @@ -46,13 +47,20 @@ type ParticipationRecord struct {
LastVote basics.Round
LastBlockProposal basics.Round
LastCompactCertificate basics.Round
RegisteredFirst basics.Round
RegisteredLast basics.Round
EffectiveFirst basics.Round
EffectiveLast basics.Round

// VRFSecrets
// OneTimeSignatureSecrets
}

var zeroParticipationRecord = ParticipationRecord{}

// IsZero returns true if the object contains zero values.
func (r ParticipationRecord) IsZero() bool {
return r == zeroParticipationRecord
}

// Duplicate creates a copy of the current object. This is required once secrets are stored.
func (r ParticipationRecord) Duplicate() ParticipationRecord {
return ParticipationRecord{
Expand All @@ -64,8 +72,8 @@ func (r ParticipationRecord) Duplicate() ParticipationRecord {
LastVote: r.LastVote,
LastBlockProposal: r.LastBlockProposal,
LastCompactCertificate: r.LastCompactCertificate,
RegisteredFirst: r.RegisteredFirst,
RegisteredLast: r.RegisteredLast,
EffectiveFirst: r.EffectiveFirst,
EffectiveLast: r.EffectiveLast,
}
}

Expand Down Expand Up @@ -119,10 +127,10 @@ type ParticipationRegistry interface {
Delete(id ParticipationID) error

// Get a participation record.
Get(id ParticipationID) (ParticipationRecord, error)
Get(id ParticipationID) ParticipationRecord

// GetAll of the participation records.
GetAll() ([]ParticipationRecord, error)
Copy link
Contributor Author

@winder winder Aug 31, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the changes in this file are related to simplifying the interface call. There isn't a DB error to pass along anymore, so no need for the error value. The caller can check for a zero value / empty list if they need to know that nothing was found.

GetAll() []ParticipationRecord

// Register updates the EffectiveFirst and EffectiveLast fields. If there are multiple records for the account
// then it is possible for multiple records to be updated.
Expand Down Expand Up @@ -192,8 +200,8 @@ var (
lastVoteRound INTEGER NOT NULL DEFAULT 0,
lastBlockProposalRound INTEGER NOT NULL DEFAULT 0,
lastCompactCertificateRound INTEGER NOT NULL DEFAULT 0,
registeredFirstRound INTEGER NOT NULL DEFAULT 0,
registeredLastRound INTEGER NOT NULL DEFAULT 0
effectiveFirstRound INTEGER NOT NULL DEFAULT 0,
effectiveLastRound INTEGER NOT NULL DEFAULT 0

-- voting BLOB, --* msgpack encoding of ParticipationAccount.voting
)`
Expand All @@ -206,7 +214,7 @@ var (
selectRecords = `SELECT
participationID, account, firstValidRound, lastValidRound, keyDilution,
lastVoteRound, lastBlockProposalRound, lastCompactCertificateRound,
registeredFirstRound, registeredLastRound
effectiveFirstRound, effectiveLastRound
FROM Keysets
INNER JOIN Rolling
ON Keysets.pk = Rolling.pk`
Expand All @@ -216,8 +224,8 @@ var (
SET lastVoteRound=?,
lastBlockProposalRound=?,
lastCompactCertificateRound=?,
registeredFirstRound=?,
registeredLastRound=?
effectiveFirstRound=?,
effectiveLastRound=?
WHERE pk IN (SELECT pk FROM Keysets WHERE participationID=?)`
)

Expand Down Expand Up @@ -331,8 +339,8 @@ func (db *participationDB) Insert(record Participation) (id ParticipationID, err
LastVote: 0,
LastBlockProposal: 0,
LastCompactCertificate: 0,
RegisteredFirst: 0,
RegisteredLast: 0,
EffectiveFirst: 0,
EffectiveLast: 0,
}
}

Expand Down Expand Up @@ -395,8 +403,8 @@ func scanRecords(rows *sql.Rows) ([]ParticipationRecord, error) {
&record.LastVote,
&record.LastBlockProposal,
&record.LastCompactCertificate,
&record.RegisteredFirst,
&record.RegisteredLast,
&record.EffectiveFirst,
&record.EffectiveLast,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -430,26 +438,26 @@ func (db *participationDB) getAllFromDB() (records []ParticipationRecord, err er
return
}

func (db *participationDB) Get(id ParticipationID) (record ParticipationRecord, err error) {
func (db *participationDB) Get(id ParticipationID) ParticipationRecord {
db.mutex.RLock()
defer db.mutex.RUnlock()

record, ok := db.cache[id]
if !ok {
return ParticipationRecord{}, ErrParticipationIDNotFound
return ParticipationRecord{}
}
return record.Duplicate(), nil
return record.Duplicate()
}

func (db *participationDB) GetAll() ([]ParticipationRecord, error) {
func (db *participationDB) GetAll() []ParticipationRecord {
db.mutex.RLock()
defer db.mutex.RUnlock()

results := make([]ParticipationRecord, 0, len(db.cache))
for _, record := range db.cache {
results = append(results, record.Duplicate())
}
return results, nil
return results
}

// updateRollingFields sets all of the rolling fields according to the record object.
Expand All @@ -458,8 +466,8 @@ func (db *participationDB) updateRollingFields(ctx context.Context, tx *sql.Tx,
record.LastVote,
record.LastBlockProposal,
record.LastCompactCertificate,
record.RegisteredFirst,
record.RegisteredLast,
record.EffectiveFirst,
record.EffectiveLast,
record.ParticipationID[:])
if err != nil {
return err
Expand All @@ -482,28 +490,35 @@ func (db *participationDB) updateRollingFields(ctx context.Context, tx *sql.Tx,
}

func recordActive(record ParticipationRecord, on basics.Round) bool {
return record.RegisteredFirst <= on && on <= record.RegisteredLast
return record.EffectiveLast != 0 && record.EffectiveFirst <= on && on <= record.EffectiveLast
}

func (db *participationDB) Register(id ParticipationID, on basics.Round) error {
// Lookup recordToRegister for first/last valid and account.
recordToRegister, err := db.Get(id)
if err != nil {
return err
recordToRegister := db.Get(id)
if recordToRegister.IsZero() {
return ErrParticipationIDNotFound
}

// No-op If the record is already active
if recordActive(recordToRegister, on) {
return nil
}

// round out of valid range.
if on < recordToRegister.FirstValid || on > recordToRegister.LastValid {
return ErrInvalidRegisterRange
}

db.mutex.Lock()
defer db.mutex.Unlock()
updated := make(map[ParticipationID]ParticipationRecord)
err = db.store.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error {
err := db.store.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error {
// Disable active key if there is one
for _, record := range db.cache {
if record.Account == recordToRegister.Account && record.ParticipationID != id && recordActive(record, on) {
// TODO: this should probably be "on - 1"
record.RegisteredLast = on
record.EffectiveLast = on
err := db.updateRollingFields(ctx, tx, record)
// Repair the case when no keys were updated
if err == ErrNoKeyForID {
Expand All @@ -519,8 +534,8 @@ func (db *participationDB) Register(id ParticipationID, on basics.Round) error {
}

// Mark registered.
recordToRegister.RegisteredFirst = on
recordToRegister.RegisteredLast = recordToRegister.LastValid
recordToRegister.EffectiveFirst = on
recordToRegister.EffectiveLast = recordToRegister.LastValid

err := db.updateRollingFields(ctx, tx, recordToRegister)
if err == ErrNoKeyForID {
Expand All @@ -537,9 +552,6 @@ func (db *participationDB) Register(id ParticipationID, on basics.Round) error {

// Update cache
if err == nil {
db.mutex.Lock()
defer db.mutex.Unlock()

for id, record := range updated {
delete(db.dirty, id)
db.cache[id] = record
Expand Down Expand Up @@ -596,7 +608,7 @@ func (db *participationDB) Flush() error {
// Verify that the dirty flag has not desynchronized from the cache.
for id := range db.dirty {
if _, ok := db.cache[id]; !ok {
db.log.Warn("participationDB fixing dirty flag de-synchronization for %s", id)
db.log.Warnf("participationDB fixing dirty flag de-synchronization for %s", id)
delete(db.cache, id)
}
}
Expand All @@ -606,13 +618,20 @@ func (db *participationDB) Flush() error {
}

err := db.store.Wdb.Atomic(func(ctx context.Context, tx *sql.Tx) error {
var errorStr strings.Builder
for id := range db.dirty {
err := db.updateRollingFields(ctx, tx, db.cache[id])
// This should only be updating key usage so ignoring missing keys is not a problem.
if err != nil && err != ErrNoKeyForID {
return err
if errorStr.Len() > 0 {
errorStr.WriteString(", ")
}
errorStr.WriteString(err.Error())
}
}
if errorStr.Len() > 0 {
return errors.New(errorStr.String())
}
return nil
})

Expand All @@ -626,7 +645,7 @@ func (db *participationDB) Flush() error {

func (db *participationDB) Close() {
if err := db.Flush(); err != nil {
db.log.Warn("participationDB unhandled error during Close/Flush: %w", err)
db.log.Warnf("participationDB unhandled error during Close/Flush: %w", err)
}

db.store.Close()
Expand Down
Loading