Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: purge mempool based on chain events #479

Merged
merged 1 commit into from
Mar 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 88 additions & 53 deletions mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@ import (
"github.com/prometheus/client_golang/prometheus/promauto"
)

const (
txsubmissionMempoolExpiration = 1 * time.Hour
txSubmissionMempoolExpirationPeriod = 1 * time.Minute
)

const (
AddTransactionEventType event.EventType = "mempool.add_tx"
RemoveTransactionEventType event.EventType = "mempool.remove_tx"
Expand Down Expand Up @@ -90,9 +85,8 @@ func NewMempool(
} else {
m.logger = logger
}
// TODO: replace this with purging based on on-chain TXs (#388)
// Schedule initial mempool expired cleanup
m.scheduleRemoveExpired()
// Subscribe to chain update events
go m.processChainEvents()
// Init metrics
promautoFactory := promauto.With(promRegistry)
m.metrics.txsProcessedNum = promautoFactory.NewCounter(
Expand Down Expand Up @@ -133,28 +127,61 @@ func (m *Mempool) Consumer(connId ouroboros.ConnectionId) *MempoolConsumer {
return m.consumers[connId]
}

// TODO: replace this with purging based on on-chain TXs (#388)
func (m *Mempool) removeExpired() {
m.Lock()
defer m.Unlock()
expiredBefore := time.Now().Add(-txsubmissionMempoolExpiration)
// We iterate backward to avoid issues with shifting indexes when deleting
for i := len(m.transactions) - 1; i >= 0; i-- {
tx := m.transactions[i]
if tx.LastSeen.Before(expiredBefore) {
m.removeTransaction(tx.Hash)
m.logger.Debug(
"removed expired transaction",
"component", "mempool",
"tx_hash", tx.Hash,
)
func (m *Mempool) processChainEvents() {
chainBlockSubId, chainBlockChan := m.eventBus.Subscribe(state.ChainBlockEventType)
chainRollbackSubId, chainRollbackChan := m.eventBus.Subscribe(state.ChainRollbackEventType)
defer func() {
m.eventBus.Unsubscribe(state.ChainBlockEventType, chainBlockSubId)
m.eventBus.Unsubscribe(state.ChainRollbackEventType, chainRollbackSubId)
}()
lastValidationTime := time.Now()
var ok bool
for {
// Wait for chain event
select {
case _, ok = <-chainBlockChan:
if !ok {
return
}
case _, ok = <-chainRollbackChan:
if !ok {
return
}
}
// Only purge once every 30 seconds when there are more blocks available
if time.Since(lastValidationTime) < 30*time.Second && len(chainBlockChan) > 0 {
continue
}
m.Lock()
// Re-validate each TX in mempool
// We iterate backward to avoid issues with shifting indexes when deleting
for i := len(m.transactions) - 1; i >= 0; i-- {
tx := m.transactions[i]
// Decode transaction
tmpTx, err := ledger.NewTransactionFromCbor(tx.Type, tx.Cbor)
if err != nil {
m.removeTransactionByIndex(i)
m.logger.Error(
"removed transaction after decode failure",
"component", "mempool",
"tx_hash", tx.Hash,
"error", err,
)
continue
}
// Validate transaction
if err := m.ledgerState.ValidateTx(tmpTx); err != nil {
m.removeTransactionByIndex(i)
m.logger.Debug(
"removed transaction after re-validation failure",
"component", "mempool",
"tx_hash", tx.Hash,
"error", err,
)
}
}
m.Unlock()
}
m.scheduleRemoveExpired()
}

func (m *Mempool) scheduleRemoveExpired() {
_ = time.AfterFunc(txSubmissionMempoolExpirationPeriod, m.removeExpired)
}

func (m *Mempool) AddTransaction(txType uint, txBytes []byte) error {
Expand Down Expand Up @@ -261,32 +288,40 @@ func (m *Mempool) RemoveTransaction(txHash string) {
func (m *Mempool) removeTransaction(txHash string) bool {
for txIdx, tx := range m.transactions {
if tx.Hash == txHash {
m.transactions = slices.Delete(
m.transactions,
txIdx,
txIdx+1,
)
m.metrics.txsInMempool.Dec()
m.metrics.mempoolBytes.Sub(float64(len(tx.Cbor)))
// Update consumer indexes to reflect removed TX
for _, consumer := range m.consumers {
// Decrement consumer index if the consumer has reached the removed TX
if consumer.nextTxIdx >= txIdx {
consumer.nextTxIdx--
}
}
// Generate event
m.eventBus.Publish(
RemoveTransactionEventType,
event.NewEvent(
RemoveTransactionEventType,
RemoveTransactionEvent{
Hash: tx.Hash,
},
),
)
return true
return m.removeTransactionByIndex(txIdx)
}
}
return false
}

func (m *Mempool) removeTransactionByIndex(txIdx int) bool {
if txIdx >= len(m.transactions) {
return false
}
tx := m.transactions[txIdx]
m.transactions = slices.Delete(
m.transactions,
txIdx,
txIdx+1,
)
m.metrics.txsInMempool.Dec()
m.metrics.mempoolBytes.Sub(float64(len(tx.Cbor)))
// Update consumer indexes to reflect removed TX
for _, consumer := range m.consumers {
// Decrement consumer index if the consumer has reached the removed TX
if consumer.nextTxIdx >= txIdx {
consumer.nextTxIdx--
}
}
// Generate event
m.eventBus.Publish(
RemoveTransactionEventType,
event.NewEvent(
RemoveTransactionEventType,
RemoveTransactionEvent{
Hash: tx.Hash,
},
),
)
return true
}