diff --git a/types/mempool/priority_nonce.go b/types/mempool/priority_nonce.go index e8db75b9f5c8..eba8311b3205 100644 --- a/types/mempool/priority_nonce.go +++ b/types/mempool/priority_nonce.go @@ -29,6 +29,7 @@ type priorityNonceMempool struct { senderIndices map[string]*skiplist.SkipList scores map[txMeta]txMeta onRead func(tx sdk.Tx) + maxTx int } type priorityNonceIterator struct { @@ -84,13 +85,24 @@ func txMetaLess(a, b any) int { type PriorityNonceMempoolOption func(*priorityNonceMempool) -// WithOnRead sets a callback to be called when a tx is read from the mempool. -func WithOnRead(onRead func(tx sdk.Tx)) PriorityNonceMempoolOption { +// PriorityNonceWithOnRead sets a callback to be called when a tx is read from the mempool. +func PriorityNonceWithOnRead(onRead func(tx sdk.Tx)) PriorityNonceMempoolOption { return func(mp *priorityNonceMempool) { mp.onRead = onRead } } +// PriorityNonceWithMaxTx sets the maximum number of transactions allowed in the mempool with the semantics: +// +// <0: disabled, `Insert` is a no-op +// 0: unlimited +// >0: maximum number of transactions allowed +func PriorityNonceWithMaxTx(maxTx int) PriorityNonceMempoolOption { + return func(mp *priorityNonceMempool) { + mp.maxTx = maxTx + } +} + // DefaultPriorityMempool returns a priorityNonceMempool with no options. func DefaultPriorityMempool() Mempool { return NewPriorityMempool() @@ -123,6 +135,12 @@ func NewPriorityMempool(opts ...PriorityNonceMempoolOption) Mempool { // Inserting a duplicate tx with a different priority overwrites the existing tx, // changing the total order of the mempool. func (mp *priorityNonceMempool) Insert(ctx context.Context, tx sdk.Tx) error { + if mp.maxTx > 0 && mp.CountTx() >= mp.maxTx { + return ErrMempoolTxMaxCapacity + } else if mp.maxTx < 0 { + return nil + } + sigs, err := tx.(signing.SigVerifiableTx).GetSignaturesV2() if err != nil { return err diff --git a/types/mempool/priority_nonce_test.go b/types/mempool/priority_nonce_test.go index f75b59b4a263..d5d76d747b54 100644 --- a/types/mempool/priority_nonce_test.go +++ b/types/mempool/priority_nonce_test.go @@ -373,7 +373,7 @@ func validateOrder(mtxs []sdk.Tx) error { func (s *MempoolTestSuite) TestRandomGeneratedTxs() { s.iterations = 0 - s.mempool = mempool.NewPriorityMempool(mempool.WithOnRead(func(tx sdk.Tx) { + s.mempool = mempool.NewPriorityMempool(mempool.PriorityNonceWithOnRead(func(tx sdk.Tx) { s.iterations++ })) t := s.T() @@ -582,3 +582,60 @@ func TestTxOrderN(t *testing.T) { fmt.Printf("%s, %d, %d\n", tx.address, tx.priority, tx.nonce) } } + +func TestTxLimit(t *testing.T) { + accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2) + ctx := sdk.NewContext(nil, tmproto.Header{}, false, log.NewNopLogger()) + sa := accounts[0].Address + sb := accounts[1].Address + + txs := []testTx{ + {priority: 20, nonce: 1, address: sa}, + {priority: 21, nonce: 1, address: sb}, + {priority: 15, nonce: 2, address: sa}, + {priority: 88, nonce: 2, address: sb}, + {priority: 66, nonce: 3, address: sa}, + {priority: 15, nonce: 3, address: sb}, + {priority: 20, nonce: 4, address: sa}, + {priority: 21, nonce: 4, address: sb}, + {priority: 88, nonce: 5, address: sa}, + {priority: 66, nonce: 5, address: sb}, + } + + // unlimited + mp := mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(0)) + for i, tx := range txs { + c := ctx.WithPriority(tx.priority) + require.NoError(t, mp.Insert(c, tx)) + require.Equal(t, i+1, mp.CountTx()) + } + mp = mempool.NewPriorityMempool() + for i, tx := range txs { + c := ctx.WithPriority(tx.priority) + require.NoError(t, mp.Insert(c, tx)) + require.Equal(t, i+1, mp.CountTx()) + } + + // limit: 3 + mp = mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(3)) + for i, tx := range txs { + c := ctx.WithPriority(tx.priority) + err := mp.Insert(c, tx) + if i < 3 { + require.NoError(t, err) + require.Equal(t, i+1, mp.CountTx()) + } else { + require.ErrorIs(t, err, mempool.ErrMempoolTxMaxCapacity) + require.Equal(t, 3, mp.CountTx()) + } + } + + // disabled + mp = mempool.NewPriorityMempool(mempool.PriorityNonceWithMaxTx(-1)) + for _, tx := range txs { + c := ctx.WithPriority(tx.priority) + err := mp.Insert(c, tx) + require.NoError(t, err) + require.Equal(t, 0, mp.CountTx()) + } +}