diff --git a/dot/core/interface.go b/dot/core/interface.go index daa6d38a0e..8c5710b097 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -45,8 +45,8 @@ type BlockState interface { GetFinalisedHash(uint64, uint64) (common.Hash, error) GetImportedBlockNotifierChannel() chan *types.Block FreeImportedBlockNotifierChannel(ch chan *types.Block) - RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) - UnregisterFinalisedChannel(id byte) + GetFinalisedNotifierChannel() chan *types.FinalisationInfo + FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) HighestCommonAncestor(a, b common.Hash) (common.Hash, error) SubChain(start, end common.Hash) ([]common.Hash, error) GetBlockBody(hash common.Hash) (*types.Body, error) diff --git a/dot/core/mocks/block_state.go b/dot/core/mocks/block_state.go index 6e5387bae4..933e7f4e54 100644 --- a/dot/core/mocks/block_state.go +++ b/dot/core/mocks/block_state.go @@ -142,6 +142,11 @@ func (_m *MockBlockState) BestBlockStateRoot() (common.Hash, error) { return r0, r1 } +// FreeFinalisedNotifierChannel provides a mock function with given fields: ch +func (_m *MockBlockState) FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) { + _m.Called(ch) +} + // FreeImportedBlockNotifierChannel provides a mock function with given fields: ch func (_m *MockBlockState) FreeImportedBlockNotifierChannel(ch chan *types.Block) { _m.Called(ch) @@ -294,6 +299,22 @@ func (_m *MockBlockState) GetFinalisedHeader(_a0 uint64, _a1 uint64) (*types.Hea return r0, r1 } +// GetFinalisedNotifierChannel provides a mock function with given fields: +func (_m *MockBlockState) GetFinalisedNotifierChannel() chan *types.FinalisationInfo { + ret := _m.Called() + + var r0 chan *types.FinalisationInfo + if rf, ok := ret.Get(0).(func() chan *types.FinalisationInfo); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(chan *types.FinalisationInfo) + } + } + + return r0 +} + // GetImportedBlockNotifierChannel provides a mock function with given fields: func (_m *MockBlockState) GetImportedBlockNotifierChannel() chan *types.Block { ret := _m.Called() @@ -391,27 +412,6 @@ func (_m *MockBlockState) HighestCommonAncestor(a common.Hash, b common.Hash) (c return r0, r1 } -// RegisterFinalizedChannel provides a mock function with given fields: ch -func (_m *MockBlockState) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { - ret := _m.Called(ch) - - var r0 byte - if rf, ok := ret.Get(0).(func(chan<- *types.FinalisationInfo) byte); ok { - r0 = rf(ch) - } else { - r0 = ret.Get(0).(byte) - } - - var r1 error - if rf, ok := ret.Get(1).(func(chan<- *types.FinalisationInfo) error); ok { - r1 = rf(ch) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // StoreRuntime provides a mock function with given fields: _a0, _a1 func (_m *MockBlockState) StoreRuntime(_a0 common.Hash, _a1 runtime.Instance) { _m.Called(_a0, _a1) @@ -439,8 +439,3 @@ func (_m *MockBlockState) SubChain(start common.Hash, end common.Hash) ([]common return r0, r1 } - -// UnregisterFinalisedChannel provides a mock function with given fields: id -func (_m *MockBlockState) UnregisterFinalisedChannel(id byte) { - _m.Called(id) -} diff --git a/dot/digest/digest.go b/dot/digest/digest.go index 30c6c492a2..ca87ac61a5 100644 --- a/dot/digest/digest.go +++ b/dot/digest/digest.go @@ -47,9 +47,8 @@ type Handler struct { grandpaState GrandpaState // block notification channels - imported chan *types.Block - finalised chan *types.FinalisationInfo - finalisedID byte + imported chan *types.Block + finalised chan *types.FinalisationInfo // GRANDPA changes grandpaScheduledChange *grandpaChange @@ -75,12 +74,7 @@ type resume struct { func NewHandler(blockState BlockState, epochState EpochState, grandpaState GrandpaState) (*Handler, error) { imported := blockState.GetImportedBlockNotifierChannel() - finalised := make(chan *types.FinalisationInfo, 16) - - fid, err := blockState.RegisterFinalizedChannel(finalised) - if err != nil { - return nil, err - } + finalised := blockState.GetFinalisedNotifierChannel() ctx, cancel := context.WithCancel(context.Background()) @@ -92,7 +86,6 @@ func NewHandler(blockState BlockState, epochState EpochState, grandpaState Grand grandpaState: grandpaState, imported: imported, finalised: finalised, - finalisedID: fid, }, nil } @@ -107,8 +100,7 @@ func (h *Handler) Start() error { func (h *Handler) Stop() error { h.cancel() h.blockState.FreeImportedBlockNotifierChannel(h.imported) - h.blockState.UnregisterFinalisedChannel(h.finalisedID) - close(h.finalised) + h.blockState.FreeFinalisedNotifierChannel(h.finalised) return nil } diff --git a/dot/digest/interface.go b/dot/digest/interface.go index f467b0a7d0..72d4668bd1 100644 --- a/dot/digest/interface.go +++ b/dot/digest/interface.go @@ -28,8 +28,8 @@ type BlockState interface { BestBlockHeader() (*types.Header, error) GetImportedBlockNotifierChannel() chan *types.Block FreeImportedBlockNotifierChannel(ch chan *types.Block) - RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) - UnregisterFinalisedChannel(id byte) + GetFinalisedNotifierChannel() chan *types.FinalisationInfo + FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) } // EpochState is the interface for state.EpochState diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index d79d9e5583..20f134b083 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -40,8 +40,8 @@ type BlockAPI interface { GetJustification(hash common.Hash) ([]byte, error) GetImportedBlockNotifierChannel() chan *types.Block FreeImportedBlockNotifierChannel(ch chan *types.Block) - RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) - UnregisterFinalisedChannel(id byte) + GetFinalisedNotifierChannel() chan *types.FinalisationInfo + FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) SubChain(start, end common.Hash) ([]common.Hash, error) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) UnregisterRuntimeUpdatedChannel(id uint32) bool diff --git a/dot/rpc/modules/api_mocks.go b/dot/rpc/modules/api_mocks.go index b9f9802eaa..3d3765b7a4 100644 --- a/dot/rpc/modules/api_mocks.go +++ b/dot/rpc/modules/api_mocks.go @@ -32,9 +32,8 @@ func NewMockBlockAPI() *modulesmocks.MockBlockAPI { m.On("GetHighestFinalisedHash").Return(common.Hash{}, nil) m.On("GetImportedBlockNotifierChannel").Return(make(chan *types.Block, 5)) m.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) - m.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) - m.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")).Return(byte(0), nil) - m.On("UnregisterFinalizedChannel", mock.AnythingOfType("uint8")) + m.On("GetFinalisedNotifierChannel").Return(make(chan *types.FinalisationInfo, 5)) + m.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) m.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(make([]byte, 10), nil) m.On("HasJustification", mock.AnythingOfType("common.Hash")).Return(true, nil) m.On("SubChain", mock.AnythingOfType("common.Hash"), mock.AnythingOfType("common.Hash")).Return(make([]common.Hash, 0), nil) diff --git a/dot/rpc/modules/mocks/block_api.go b/dot/rpc/modules/mocks/block_api.go index d343fe1627..cf6d9382cf 100644 --- a/dot/rpc/modules/mocks/block_api.go +++ b/dot/rpc/modules/mocks/block_api.go @@ -34,6 +34,11 @@ func (_m *MockBlockAPI) BestBlockHash() common.Hash { return r0 } +// FreeFinalisedNotifierChannel provides a mock function with given fields: ch +func (_m *MockBlockAPI) FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) { + _m.Called(ch) +} + // FreeImportedBlockNotifierChannel provides a mock function with given fields: ch func (_m *MockBlockAPI) FreeImportedBlockNotifierChannel(ch chan *types.Block) { _m.Called(ch) @@ -108,6 +113,22 @@ func (_m *MockBlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, e return r0, r1 } +// GetFinalisedNotifierChannel provides a mock function with given fields: +func (_m *MockBlockAPI) GetFinalisedNotifierChannel() chan *types.FinalisationInfo { + ret := _m.Called() + + var r0 chan *types.FinalisationInfo + if rf, ok := ret.Get(0).(func() chan *types.FinalisationInfo); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(chan *types.FinalisationInfo) + } + } + + return r0 +} + // GetHeader provides a mock function with given fields: hash func (_m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { ret := _m.Called(hash) @@ -214,27 +235,6 @@ func (_m *MockBlockAPI) HasJustification(hash common.Hash) (bool, error) { return r0, r1 } -// RegisterFinalizedChannel provides a mock function with given fields: ch -func (_m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { - ret := _m.Called(ch) - - var r0 byte - if rf, ok := ret.Get(0).(func(chan<- *types.FinalisationInfo) byte); ok { - r0 = rf(ch) - } else { - r0 = ret.Get(0).(byte) - } - - var r1 error - if rf, ok := ret.Get(1).(func(chan<- *types.FinalisationInfo) error); ok { - r1 = rf(ch) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // RegisterRuntimeUpdatedChannel provides a mock function with given fields: ch func (_m *MockBlockAPI) RegisterRuntimeUpdatedChannel(ch chan<- runtime.Version) (uint32, error) { ret := _m.Called(ch) @@ -279,11 +279,6 @@ func (_m *MockBlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.H return r0, r1 } -// UnregisterFinalisedChannel provides a mock function with given fields: id -func (_m *MockBlockAPI) UnregisterFinalisedChannel(id byte) { - _m.Called(id) -} - // UnregisterRuntimeUpdatedChannel provides a mock function with given fields: id func (_m *MockBlockAPI) UnregisterRuntimeUpdatedChannel(id uint32) bool { ret := _m.Called(id) diff --git a/dot/rpc/subscription/listeners.go b/dot/rpc/subscription/listeners.go index 6299d7599c..32cab72054 100644 --- a/dot/rpc/subscription/listeners.go +++ b/dot/rpc/subscription/listeners.go @@ -178,7 +178,6 @@ func (l *BlockListener) Stop() error { type BlockFinalizedListener struct { channel chan *types.FinalisationInfo wsconn *WSConn - chanID byte subID uint32 done chan struct{} cancel chan struct{} @@ -189,7 +188,7 @@ type BlockFinalizedListener struct { func (l *BlockFinalizedListener) Listen() { go func() { defer func() { - l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.chanID) + l.wsconn.BlockAPI.FreeFinalisedNotifierChannel(l.channel) close(l.done) }() @@ -229,12 +228,11 @@ type AllBlocksListener struct { finalizedChan chan *types.FinalisationInfo importedChan chan *types.Block - wsconn *WSConn - finalizedChanID byte - subID uint32 - done chan struct{} - cancel chan struct{} - cancelTimeout time.Duration + wsconn *WSConn + subID uint32 + done chan struct{} + cancel chan struct{} + cancelTimeout time.Duration } func newAllBlockListener(conn *WSConn) *AllBlocksListener { @@ -243,7 +241,6 @@ func newAllBlockListener(conn *WSConn) *AllBlocksListener { done: make(chan struct{}, 1), cancelTimeout: defaultCancelTimeout, wsconn: conn, - finalizedChan: make(chan *types.FinalisationInfo, DEFAULT_BUFFER_SIZE), } } @@ -252,9 +249,8 @@ func (l *AllBlocksListener) Listen() { go func() { defer func() { l.wsconn.BlockAPI.FreeImportedBlockNotifierChannel(l.importedChan) - l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.finalizedChanID) + l.wsconn.BlockAPI.FreeFinalisedNotifierChannel(l.finalizedChan) - close(l.finalizedChan) close(l.done) }() @@ -307,16 +303,15 @@ func (l *AllBlocksListener) Stop() error { // ExtrinsicSubmitListener to handle listening for extrinsic events type ExtrinsicSubmitListener struct { - wsconn *WSConn - subID uint32 - extrinsic types.Extrinsic - importedChan chan *types.Block - importedHash common.Hash - finalisedChan chan *types.FinalisationInfo - finalisedChanID byte - done chan struct{} - cancel chan struct{} - cancelTimeout time.Duration + wsconn *WSConn + subID uint32 + extrinsic types.Extrinsic + importedChan chan *types.Block + importedHash common.Hash + finalisedChan chan *types.FinalisationInfo + done chan struct{} + cancel chan struct{} + cancelTimeout time.Duration } // NewExtrinsicSubmitListener constructor to build new ExtrinsicSubmitListener @@ -338,7 +333,7 @@ func (l *ExtrinsicSubmitListener) Listen() { go func() { defer func() { l.wsconn.BlockAPI.FreeImportedBlockNotifierChannel(l.importedChan) - l.wsconn.BlockAPI.UnregisterFinalisedChannel(l.finalisedChanID) + l.wsconn.BlockAPI.FreeFinalisedNotifierChannel(l.finalisedChan) close(l.done) close(l.finalisedChan) }() @@ -459,7 +454,6 @@ type GrandpaJustificationListener struct { done chan struct{} wsconn *WSConn subID uint32 - finalisedChID byte finalisedCh chan *types.FinalisationInfo } @@ -468,7 +462,7 @@ func (g *GrandpaJustificationListener) Listen() { // listen for finalised headers go func() { defer func() { - g.wsconn.BlockAPI.UnregisterFinalisedChannel(g.finalisedChID) + g.wsconn.BlockAPI.FreeFinalisedNotifierChannel(g.finalisedCh) close(g.done) }() diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index a192c27b94..506a524b47 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -146,7 +146,7 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { defer cancel() BlockAPI := new(mocks.MockBlockAPI) - BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) wsconn.BlockAPI = BlockAPI @@ -165,7 +165,7 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { defer func() { require.NoError(t, bfl.Stop()) time.Sleep(time.Millisecond * 10) - BlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) }() notifyChan <- &types.FinalisationInfo{ @@ -198,7 +198,7 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { BlockAPI := new(mocks.MockBlockAPI) BlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) - BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) wsconn.BlockAPI = BlockAPI @@ -227,7 +227,7 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { time.Sleep(time.Millisecond * 10) BlockAPI.AssertCalled(t, "FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) - BlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) }() notifyImportedChan <- block @@ -272,7 +272,7 @@ func TestGrandpaJustification_Listen(t *testing.T) { blockStateMock := new(mocks.MockBlockAPI) blockStateMock.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) - blockStateMock.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + blockStateMock.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) wsconn.BlockAPI = blockStateMock finchannel := make(chan *types.FinalisationInfo) diff --git a/dot/rpc/subscription/websocket.go b/dot/rpc/subscription/websocket.go index ff8742407f..167a0a9ccf 100644 --- a/dot/rpc/subscription/websocket.go +++ b/dot/rpc/subscription/websocket.go @@ -29,7 +29,6 @@ import ( "sync/atomic" "github.com/ChainSafe/gossamer/dot/rpc/modules" - "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/runtime" log "github.com/ChainSafe/log15" @@ -236,7 +235,6 @@ func (c *WSConn) initBlockListener(reqID float64, _ interface{}) (Listener, erro func (c *WSConn) initBlockFinalizedListener(reqID float64, _ interface{}) (Listener, error) { bfl := &BlockFinalizedListener{ - channel: make(chan *types.FinalisationInfo), cancel: make(chan struct{}, 1), done: make(chan struct{}, 1), cancelTimeout: defaultCancelTimeout, @@ -248,11 +246,7 @@ func (c *WSConn) initBlockFinalizedListener(reqID float64, _ interface{}) (Liste return nil, fmt.Errorf("error BlockAPI not set") } - var err error - bfl.chanID, err = c.BlockAPI.RegisterFinalizedChannel(bfl.channel) - if err != nil { - return nil, err - } + bfl.channel = c.BlockAPI.GetFinalisedNotifierChannel() c.mu.Lock() @@ -276,13 +270,7 @@ func (c *WSConn) initAllBlocksListerner(reqID float64, _ interface{}) (Listener, } listener.importedChan = c.BlockAPI.GetImportedBlockNotifierChannel() - - var err error - listener.finalizedChanID, err = c.BlockAPI.RegisterFinalizedChannel(listener.finalizedChan) - if err != nil { - c.safeSendError(reqID, nil, "could not register finalised channel") - return nil, fmt.Errorf("could not register finalised channel") - } + listener.finalizedChan = c.BlockAPI.GetFinalisedNotifierChannel() c.mu.Lock() listener.subID = atomic.AddUint32(&c.qtyListeners, 1) @@ -309,10 +297,7 @@ func (c *WSConn) initExtrinsicWatch(reqID float64, params interface{}) (Listener esl.importedChan = c.BlockAPI.GetImportedBlockNotifierChannel() - esl.finalisedChanID, err = c.BlockAPI.RegisterFinalizedChannel(esl.finalisedChan) - if err != nil { - return nil, err - } + esl.finalisedChan = c.BlockAPI.GetFinalisedNotifierChannel() c.mu.Lock() @@ -377,15 +362,10 @@ func (c *WSConn) initGrandpaJustificationListener(reqID float64, _ interface{}) cancel: make(chan struct{}, 1), done: make(chan struct{}, 1), wsconn: c, - finalisedCh: make(chan *types.FinalisationInfo, 1), cancelTimeout: defaultCancelTimeout, } - var err error - jl.finalisedChID, err = c.BlockAPI.RegisterFinalizedChannel(jl.finalisedCh) - if err != nil { - return nil, err - } + jl.finalisedCh = c.BlockAPI.GetFinalisedNotifierChannel() c.mu.Lock() diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index e87605c288..e4eb23effb 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -1,7 +1,6 @@ package subscription import ( - "errors" "fmt" "math/big" "testing" @@ -218,7 +217,6 @@ func TestWSConn_HandleComm(t *testing.T) { require.NoError(t, err) require.Equal(t, `{"jsonrpc":"2.0","method":"author_extrinsicUpdate","params":{"result":"ready","subscription":8}}`+"\n", string(msg)) - var fCh chan<- *types.FinalisationInfo mockedJust := grandpa.Justification{ Round: 1, Commit: grandpa.Commit{ @@ -232,15 +230,12 @@ func TestWSConn_HandleComm(t *testing.T) { require.NoError(t, err) BlockAPI := new(mocks.MockBlockAPI) - BlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). - Run(func(args mock.Arguments) { - ch := args.Get(0).(chan<- *types.FinalisationInfo) - fCh = ch - }). - Return(uint8(4), nil) + + fCh := make(chan *types.FinalisationInfo, 5) + BlockAPI.On("GetFinalisedNotifierChannel").Return(fCh) BlockAPI.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) - BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) wsconn.BlockAPI = BlockAPI listener, err := wsconn.initGrandpaJustificationListener(0, nil) @@ -291,26 +286,11 @@ func TestSubscribeAllHeads(t *testing.T) { wsconn.BlockAPI = mockBlockAPI - mockBlockAPI.On("GetImportedBlockNotifierChannel").Return(make(chan *types.Block)).Once() - mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). - Return(uint8(0), errors.New("failed")).Once() - - _, err = wsconn.initAllBlocksListerner(1, nil) - require.Error(t, err, "could not register finalised channel") - c.ReadMessage() - - finalizedChanID := uint8(11) - - var fCh chan<- *types.FinalisationInfo iCh := make(chan *types.Block) mockBlockAPI.On("GetImportedBlockNotifierChannel").Return(iCh).Once() - mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). - Run(func(args mock.Arguments) { - ch := args.Get(0).(chan<- *types.FinalisationInfo) - fCh = ch - }). - Return(finalizedChanID, nil).Once() + fCh := make(chan *types.FinalisationInfo) + mockBlockAPI.On("GetFinalisedNotifierChannel").Return(fCh).Once() l, err := wsconn.initAllBlocksListerner(1, nil) require.NoError(t, err) @@ -369,9 +349,8 @@ func TestSubscribeAllHeads(t *testing.T) { require.Equal(t, []byte(expected+"\n"), msg) mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) - mockBlockAPI.On("UnregisterFinalisedChannel", finalizedChanID) + mockBlockAPI.On("FreeFinalisedNotifierChannel", mock.AnythingOfType("chan *types.FinalisationInfo")) require.NoError(t, l.Stop()) mockBlockAPI.On("FreeImportedBlockNotifierChannel", mock.AnythingOfType("chan *types.Block")) - mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", finalizedChanID) } diff --git a/dot/state/block.go b/dot/state/block.go index 7f3690fe5c..c340c0f85b 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -53,10 +53,9 @@ type BlockState struct { // block notifiers imported map[chan *types.Block]struct{} - finalised map[byte]chan<- *types.FinalisationInfo + finalised map[chan *types.FinalisationInfo]struct{} finalisedLock sync.RWMutex importedLock sync.RWMutex - finalisedBytePool *common.BytePool runtimeUpdateSubscriptionsLock sync.RWMutex runtimeUpdateSubscriptions map[uint32]chan<- runtime.Version @@ -75,7 +74,7 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), imported: make(map[chan *types.Block]struct{}), - finalised: make(map[byte]chan<- *types.FinalisationInfo), + finalised: make(map[chan *types.FinalisationInfo]struct{}), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), } @@ -91,7 +90,6 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e return nil, fmt.Errorf("failed to get last finalised hash: %w", err) } - bs.finalisedBytePool = common.NewBytePool256() return bs, nil } @@ -102,7 +100,7 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block baseState: NewBaseState(db), db: chaindb.NewTable(db, blockPrefix), imported: make(map[chan *types.Block]struct{}), - finalised: make(map[byte]chan<- *types.FinalisationInfo), + finalised: make(map[chan *types.FinalisationInfo]struct{}), pruneKeyCh: make(chan *types.Header, pruneKeyBufferSize), runtimeUpdateSubscriptions: make(map[uint32]chan<- runtime.Version), } @@ -134,7 +132,6 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block return nil, err } - bs.finalisedBytePool = common.NewBytePool256() return bs, nil } diff --git a/dot/state/block_notify.go b/dot/state/block_notify.go index ce58e0304e..88924071f6 100644 --- a/dot/state/block_notify.go +++ b/dot/state/block_notify.go @@ -39,22 +39,16 @@ func (bs *BlockState) GetImportedBlockNotifierChannel() chan *types.Block { return ch } -// RegisterFinalizedChannel registers a channel for block notification upon block finalisation. -// It returns the channel ID (used for unregistering the channel) -func (bs *BlockState) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { - bs.finalisedLock.RLock() - - id, err := bs.finalisedBytePool.Get() - if err != nil { - return 0, err - } +//nolint +// GetFinalisedNotifierChannel function to retrieve a finalized block notifier channel +func (bs *BlockState) GetFinalisedNotifierChannel() chan *types.FinalisationInfo { + bs.finalisedLock.Lock() + defer bs.finalisedLock.Unlock() - bs.finalisedLock.RUnlock() + ch := make(chan *types.FinalisationInfo, DEFAULT_BUFFER_SIZE) + bs.finalised[ch] = struct{}{} - bs.finalisedLock.Lock() - bs.finalised[id] = ch - bs.finalisedLock.Unlock() - return id, nil + return ch } // FreeImportedBlockNotifierChannel to free imported block notifier channel @@ -64,17 +58,13 @@ func (bs *BlockState) FreeImportedBlockNotifierChannel(ch chan *types.Block) { delete(bs.imported, ch) } -// UnregisterFinalisedChannel removes the block finalisation notification channel with the given ID. -// A channel must be unregistered before closing it. -func (bs *BlockState) UnregisterFinalisedChannel(id byte) { +//nolint +// FreeFinalisedNotifierChannel to free finalized notifier channel +func (bs *BlockState) FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) { bs.finalisedLock.Lock() defer bs.finalisedLock.Unlock() - delete(bs.finalised, id) - err := bs.finalisedBytePool.Put(id) - if err != nil { - logger.Error("failed to unregister finalised channel", "error", err) - } + delete(bs.finalised, ch) } func (bs *BlockState) notifyImported(block *types.Block) { @@ -117,8 +107,8 @@ func (bs *BlockState) notifyFinalized(hash common.Hash, round, setID uint64) { SetID: setID, } - for _, ch := range bs.finalised { - go func(ch chan<- *types.FinalisationInfo) { + for ch := range bs.finalised { + go func(ch chan *types.FinalisationInfo) { select { case ch <- info: default: diff --git a/dot/state/block_notify_test.go b/dot/state/block_notify_test.go index cd9bcafe6c..c9b1e366a7 100644 --- a/dot/state/block_notify_test.go +++ b/dot/state/block_notify_test.go @@ -59,11 +59,9 @@ func TestFreeImportedBlockNotifierChannel(t *testing.T) { func TestFinalizedChannel(t *testing.T) { bs := newTestBlockState(t, testGenesisHeader) - ch := make(chan *types.FinalisationInfo, 3) - id, err := bs.RegisterFinalizedChannel(ch) - require.NoError(t, err) + ch := bs.GetFinalisedNotifierChannel() - defer bs.UnregisterFinalisedChannel(id) + defer bs.FreeFinalisedNotifierChannel(ch) chain, _ := AddBlocksToState(t, bs, 3) @@ -118,13 +116,9 @@ func TestFinalizedChannel_Multi(t *testing.T) { num := 5 chs := make([]chan *types.FinalisationInfo, num) - ids := make([]byte, num) - var err error for i := 0; i < num; i++ { - chs[i] = make(chan *types.FinalisationInfo) - ids[i], err = bs.RegisterFinalizedChannel(chs[i]) - require.NoError(t, err) + chs[i] = bs.GetFinalisedNotifierChannel() } chain, _ := AddBlocksToState(t, bs, 1) @@ -149,8 +143,8 @@ func TestFinalizedChannel_Multi(t *testing.T) { bs.SetFinalisedHash(chain[0].Hash(), 1, 0) wg.Wait() - for _, id := range ids { - bs.UnregisterFinalisedChannel(id) + for _, ch := range chs { + bs.FreeFinalisedNotifierChannel(ch) } } diff --git a/lib/common/bytepool.go b/lib/common/bytepool.go deleted file mode 100644 index f708a04f47..0000000000 --- a/lib/common/bytepool.go +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright 2019 ChainSafe Systems (ON) Corp. -// This file is part of gossamer. -// -// The gossamer library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The gossamer library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the gossamer library. If not, see . - -package common - -import "fmt" - -// BytePool struct to hold byte objects that will be contained in pool -type BytePool struct { - c chan byte -} - -// NewBytePool256 creates and initialises pool with 256 entries -func NewBytePool256() *BytePool { - bp := NewBytePool(256) - for i := 0; i < 256; i++ { - _ = bp.Put(byte(i)) - } - return bp -} - -// NewBytePool creates a new empty byte pool with capacity of size -func NewBytePool(size int) (bp *BytePool) { - return &BytePool{ - c: make(chan byte, size), - } -} - -// Get gets a Buffer from the BytePool, or creates a new one if none are -// available in the pool. -func (bp *BytePool) Get() (b byte, err error) { - select { - case b = <-bp.c: - default: - err = fmt.Errorf("all slots used") - } - return -} - -// Put returns the given Buffer to the BytePool. -func (bp *BytePool) Put(b byte) error { - select { - case bp.c <- b: - return nil - default: - return fmt.Errorf("pool is full") - } -} - -// Len returns the number of items currently pooled. -func (bp *BytePool) Len() int { - return len(bp.c) -} diff --git a/lib/common/bytepool_test.go b/lib/common/bytepool_test.go deleted file mode 100644 index 997de16516..0000000000 --- a/lib/common/bytepool_test.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2019 ChainSafe Systems (ON) Corp. -// This file is part of gossamer. -// -// The gossamer library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The gossamer library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the gossamer library. If not, see . - -package common - -import ( - "math/rand" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestBytePool(t *testing.T) { - bp := NewBytePool(5) - require.Equal(t, 0, bp.Len()) - - for i := 0; i < 5; i++ { - err := bp.Put(generateID()) - require.NoError(t, err) - } - err := bp.Put(generateID()) - require.EqualError(t, err, "pool is full") - require.Equal(t, 5, bp.Len()) - - for i := 0; i < 5; i++ { - _, err := bp.Get() // nolint - require.NoError(t, err) - } - _, err = bp.Get() - require.EqualError(t, err, "all slots used") -} - -func TestBytePool256(t *testing.T) { - bp := NewBytePool256() - require.Equal(t, 256, bp.Len()) - - for i := 0; i < 256; i++ { - _, err := bp.Get() // nolint - require.NoError(t, err) - } - _, err := bp.Get() - require.EqualError(t, err, "all slots used") -} - -func generateID() byte { - // skipcq: GSC-G404 - id := rand.Intn(256) //nolint - return byte(id) -} diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index 3d2960c7f2..3900355359 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -79,7 +79,6 @@ type Service struct { // channels for communication with other services in chan *networkVoteMessage // only used to receive *VoteMessage finalisedCh chan *types.FinalisationInfo - finalisedChID byte neighbourMessage *NeighbourMessage // cached neighbour message } @@ -139,11 +138,7 @@ func NewService(cfg *Config) (*Service, error) { return nil, err } - finalisedCh := make(chan *types.FinalisationInfo, 16) - fid, err := cfg.BlockState.RegisterFinalizedChannel(finalisedCh) - if err != nil { - return nil, err - } + finalisedCh := cfg.BlockState.GetFinalisedNotifierChannel() round, err := cfg.GrandpaState.GetLatestRound() if err != nil { @@ -171,7 +166,6 @@ func NewService(cfg *Config) (*Service, error) { resumed: make(chan struct{}), network: cfg.Network, finalisedCh: finalisedCh, - finalisedChID: fid, } s.messageHandler = NewMessageHandler(s, s.blockState) @@ -212,8 +206,7 @@ func (s *Service) Stop() error { s.cancel() - s.blockState.UnregisterFinalisedChannel(s.finalisedChID) - close(s.finalisedCh) + s.blockState.FreeFinalisedNotifierChannel(s.finalisedCh) if !s.authority { return nil diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index c83352a007..6b5ff3c839 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -43,8 +43,8 @@ type BlockState interface { BlocktreeAsString() string GetImportedBlockNotifierChannel() chan *types.Block FreeImportedBlockNotifierChannel(ch chan *types.Block) - RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) - UnregisterFinalisedChannel(id byte) + GetFinalisedNotifierChannel() chan *types.FinalisationInfo + FreeFinalisedNotifierChannel(ch chan *types.FinalisationInfo) SetJustification(hash common.Hash, data []byte) error HasJustification(hash common.Hash) (bool, error) GetJustification(hash common.Hash) ([]byte, error)