diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index ab31795702..05e1c9b71f 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -212,14 +212,6 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err return errors.New("genesis hash mismatch") } - // if peer has higher best block than us, begin syncing - latestHeader, err := s.blockState.BestBlockHeader() - if err != nil { - return err - } - - bestBlockNum := big.NewInt(int64(bhs.BestBlockNumber)) - np, ok := s.notificationsProtocols[BlockAnnounceMsgType] if !ok { // this should never happen. @@ -239,6 +231,14 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err data.handshake = hs + // if peer has higher best block than us, begin syncing + latestHeader, err := s.blockState.BestBlockHeader() + if err != nil { + return err + } + + bestBlockNum := big.NewInt(int64(bhs.BestBlockNumber)) + // check if peer block number is greater than host block number if latestHeader.Number.Cmp(bestBlockNum) >= 0 { return nil diff --git a/dot/network/host.go b/dot/network/host.go index 787b0c0e96..79594760e4 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -237,7 +237,7 @@ func (h *host) bootstrap() { // peer (gets the already opened outbound message stream or opens a new one). func (h *host) send(p peer.ID, pid protocol.ID, msg Message) (err error) { // get outbound stream for given peer - s := h.getStream(p, pid) + s := h.getOutboundStream(p, pid) // check if stream needs to be opened if s == nil { @@ -286,10 +286,10 @@ func (h *host) writeToStream(s libp2pnetwork.Stream, msg Message) error { return err } -// getStream returns the outbound message stream for the given peer or returns +// getOutboundStream returns the outbound message stream for the given peer or returns // nil if no outbound message stream exists. For each peer, each host opens an // outbound message stream and writes to the same stream until closed or reset. -func (h *host) getStream(p peer.ID, pid protocol.ID) (stream libp2pnetwork.Stream) { +func (h *host) getOutboundStream(p peer.ID, pid protocol.ID) (stream libp2pnetwork.Stream) { conns := h.h.Network().ConnsToPeer(p) // loop through connections (only one for now) @@ -310,7 +310,7 @@ func (h *host) getStream(p peer.ID, pid protocol.ID) (stream libp2pnetwork.Strea // closeStream closes a stream open to the peer with the given sub-protocol, if it exists. func (h *host) closeStream(p peer.ID, pid protocol.ID) { - stream := h.getStream(p, pid) + stream := h.getOutboundStream(p, pid) if stream != nil { _ = stream.Close() } diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 56430d94cb..aca5f2313e 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -269,7 +269,7 @@ func TestExistingStream(t *testing.T) { } require.NoError(t, err) - stream := nodeA.host.getStream(nodeB.host.id(), nodeB.host.protocolID) + stream := nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) require.Nil(t, stream, "node A should not have an outbound stream") // node A opens the stream to send the first message @@ -279,7 +279,7 @@ func TestExistingStream(t *testing.T) { time.Sleep(TestMessageTimeout) require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A") - stream = nodeA.host.getStream(nodeB.host.id(), nodeB.host.protocolID) + stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) require.NotNil(t, stream, "node A should have an outbound stream") // node A uses the stream to send a second message @@ -287,7 +287,7 @@ func TestExistingStream(t *testing.T) { require.NoError(t, err) require.NotNil(t, handlerB.messages[nodeA.host.id()], "node B timeout waiting for message from node A") - stream = nodeA.host.getStream(nodeB.host.id(), nodeB.host.protocolID) + stream = nodeA.host.getOutboundStream(nodeB.host.id(), nodeB.host.protocolID) require.NotNil(t, stream, "node B should have an outbound stream") // node B opens the stream to send the first message @@ -297,7 +297,7 @@ func TestExistingStream(t *testing.T) { time.Sleep(TestMessageTimeout) require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B") - stream = nodeB.host.getStream(nodeA.host.id(), nodeB.host.protocolID) + stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID) require.NotNil(t, stream, "node B should have an outbound stream") // node B uses the stream to send a second message @@ -305,7 +305,7 @@ func TestExistingStream(t *testing.T) { require.NoError(t, err) require.NotNil(t, handlerA.messages[nodeB.host.id()], "node A timeout waiting for message from node B") - stream = nodeB.host.getStream(nodeA.host.id(), nodeB.host.protocolID) + stream = nodeB.host.getOutboundStream(nodeA.host.id(), nodeB.host.protocolID) require.NotNil(t, stream, "node B should have an outbound stream") } diff --git a/dot/network/notifications.go b/dot/network/notifications.go index a733d06b78..98760d6e18 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -106,12 +106,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return errors.New("message is not NotificationsMessage") } - logger.Trace("received message on notifications sub-protocol", "protocol", info.protocolID, - "message", msg, - "peer", stream.Conn().RemotePeer(), - ) - if msg.IsHandshake() { + logger.Trace("received handshake on notifications sub-protocol", "protocol", info.protocolID, + "message", msg, + "peer", stream.Conn().RemotePeer(), + ) + hs, ok := msg.(Handshake) if !ok { return errors.New("failed to convert message to Handshake") @@ -186,6 +186,11 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, return nil } + logger.Debug("received message on notifications sub-protocol", "protocol", info.protocolID, + "message", msg, + "peer", stream.Conn().RemotePeer(), + ) + err := messageHandler(peer, msg) if err != nil { return err diff --git a/dot/network/service.go b/dot/network/service.go index a970812985..6c454d611a 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -304,7 +304,7 @@ func (s *Service) handleConn(conn libp2pnetwork.Conn) { defer info.mapMu.RUnlock() peer := conn.RemotePeer() - if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { + if hsData, has := info.getHandshakeData(peer); !has || !hsData.received { //nolint info.handshakeData.Store(peer, &handshakeData{ validated: false, }) @@ -428,6 +428,9 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, info := s.notificationsProtocols[messageID] + decoder := createDecoder(info, handshakeDecoder, messageDecoder) + handlerWithValidate := s.createNotificationsMessageHandler(info, handshakeValidator, messageHandler) + s.host.registerStreamHandlerWithOverwrite(sub, overwriteProtocol, func(stream libp2pnetwork.Stream) { logger.Trace("received stream", "sub-protocol", sub) conn := stream.Conn() @@ -437,10 +440,6 @@ func (s *Service) RegisterNotificationsProtocol(sub protocol.ID, } p := conn.RemotePeer() - - decoder := createDecoder(info, handshakeDecoder, messageDecoder) - handlerWithValidate := s.createNotificationsMessageHandler(info, handshakeValidator, messageHandler) - s.readStream(stream, p, decoder, handlerWithValidate) }) @@ -537,7 +536,7 @@ func (s *Service) readStream(stream libp2pnetwork.Stream, peer peer.ID, decoder // decode message based on message type msg, err := decoder(msgBytes[:tot], peer) if err != nil { - logger.Trace("Failed to decode message from peer", "peer", peer, "err", err) + logger.Trace("failed to decode message from peer", "protocol", stream.Protocol(), "err", err) continue } diff --git a/dot/network/state.go b/dot/network/state.go index 77725bd7de..72a4831e35 100644 --- a/dot/network/state.go +++ b/dot/network/state.go @@ -30,6 +30,7 @@ type BlockState interface { GenesisHash() common.Hash HasBlockBody(common.Hash) (bool, error) GetFinalizedHeader(round, setID uint64) (*types.Header, error) + GetHashByNumber(num *big.Int) (common.Hash, error) } // Syncer is implemented by the syncing service diff --git a/dot/network/state_test.go b/dot/network/state_test.go index abc41d7a62..a268b9914e 100644 --- a/dot/network/state_test.go +++ b/dot/network/state_test.go @@ -75,3 +75,7 @@ func (mbs *MockBlockState) HasBlockBody(common.Hash) (bool, error) { func (mbs *MockBlockState) GetFinalizedHeader(_, _ uint64) (*types.Header, error) { return mbs.BestBlockHeader() } + +func (mbs *MockBlockState) GetHashByNumber(_ *big.Int) (common.Hash, error) { + return common.Hash{}, nil +} diff --git a/dot/network/sync.go b/dot/network/sync.go index 70d822d577..9cbc5d21e0 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -85,12 +85,20 @@ func (s *Service) handleSyncMessage(stream libp2pnetwork.Stream, msg Message) er } var ( - blockRequestSize uint32 = 128 - blockRequestBufferSize int = 6 + blockRequestSize uint32 = 128 + blockRequestBufferSize int = 6 + blockResponseBufferSize int = 6 maxBlockResponseSize uint64 = 1024 * 1024 * 4 // 4mb badPeerThreshold int = -2 protectedPeerThreshold int = 7 + + defaultSlotDuration = time.Second * 6 +) + +var ( + errEmptyResponseData = fmt.Errorf("response data is empty") + errEmptyJustificationData = fmt.Errorf("no justifications in response data") ) type syncPeer struct { @@ -110,13 +118,15 @@ type requestData struct { } type syncQueue struct { - s *Service - ctx context.Context - cancel context.CancelFunc - peerScore *sync.Map // map[peer.ID]int; peers we have successfully synced from before -> their score; score increases on successful response + s *Service + slotDuration time.Duration + ctx context.Context + cancel context.CancelFunc + peerScore *sync.Map // map[peer.ID]int; peers we have successfully synced from before -> their score; score increases on successful response - requestData *sync.Map // map[uint64]requestData; map of start # of request -> requestData - requestCh chan *syncRequest + requestData *sync.Map // map[uint64]requestData; map of start # of request -> requestData + justificationRequestData *sync.Map // map[common.Hash]requestData; map of requests of justifications -> requestData + requestCh chan *syncRequest responses []*types.BlockData responseCh chan []*types.BlockData @@ -133,22 +143,25 @@ func newSyncQueue(s *Service) *syncQueue { ctx, cancel := context.WithCancel(s.ctx) return &syncQueue{ - s: s, - ctx: ctx, - cancel: cancel, - peerScore: new(sync.Map), - requestData: new(sync.Map), - requestCh: make(chan *syncRequest, blockRequestBufferSize), - responses: []*types.BlockData{}, - responseCh: make(chan []*types.BlockData), - benchmarker: newSyncBenchmarker(), - buf: make([]byte, maxBlockResponseSize), + s: s, + slotDuration: defaultSlotDuration, + ctx: ctx, + cancel: cancel, + peerScore: new(sync.Map), + requestData: new(sync.Map), + justificationRequestData: new(sync.Map), + requestCh: make(chan *syncRequest, blockRequestBufferSize), + responses: []*types.BlockData{}, + responseCh: make(chan []*types.BlockData, blockResponseBufferSize), + benchmarker: newSyncBenchmarker(), + buf: make([]byte, maxBlockResponseSize), } } func (q *syncQueue) start() { go q.handleResponseQueue() go q.syncAtHead() + go q.finalizeAtHead() go q.processBlockRequests() go q.processBlockResponses() @@ -169,7 +182,7 @@ func (q *syncQueue) syncAtHead() { for { select { // sleep for average block time TODO: make this configurable from slot duration - case <-time.After(time.Second * 6): + case <-time.After(q.slotDuration): case <-q.ctx.Done(): return } @@ -180,7 +193,7 @@ func (q *syncQueue) syncAtHead() { } // we aren't at the head yet, sleep - if curr.Number.Int64() < q.goal { + if curr.Number.Int64() < q.goal && curr.Number.Cmp(prev.Number) > 0 { prev = curr continue } @@ -395,17 +408,7 @@ func (q *syncQueue) pushRequest(start uint64, numRequests int, to peer.ID) { start := best.Int64() + 1 req := createBlockRequest(start, 0) - if d, has := q.requestData.Load(start); has { - data := d.(requestData) - // we haven't sent the request out yet, or we've already gotten the response - if !data.sent || data.sent && data.received { - logger.Debug("ignoring request, already received data", "start", start) - return - } - } - logger.Debug("pushing request to queue", "start", start) - q.requestData.Store(start, requestData{ received: false, }) @@ -453,7 +456,35 @@ func (q *syncQueue) pushRequest(start uint64, numRequests int, to peer.ID) { func (q *syncQueue) pushResponse(resp *BlockResponseMessage, pid peer.ID) error { if len(resp.BlockData) == 0 { - return fmt.Errorf("response data is empty") + return errEmptyResponseData + } + + startHash := resp.BlockData[0].Hash + if _, has := q.justificationRequestData.Load(startHash); has && !resp.BlockData[0].Header.Exists() { + numJustifications := 0 + justificationResponses := []*types.BlockData{} + + for _, bd := range resp.BlockData { + if bd.Justification.Exists() { + justificationResponses = append(justificationResponses, bd) + numJustifications++ + } + } + + if numJustifications == 0 { + return errEmptyJustificationData + } + + q.updatePeerScore(pid, 1) + q.justificationRequestData.Store(startHash, requestData{ + sent: true, + received: true, + from: pid, + }) + + logger.Info("pushed justification data to queue", "hash", startHash) + q.responseCh <- justificationResponses + return nil } start, end, err := resp.getStartAndEnd() @@ -525,7 +556,7 @@ func (q *syncQueue) trySync(req *syncRequest) { return } - logger.Debug("beginning to send out request", "start", req.req.StartingBlock.Value()) + logger.Trace("beginning to send out request", "start", req.req.StartingBlock.Value()) if len(req.to) != 0 { resp, err := q.syncWithPeer(req.to, req.req) if err == nil { @@ -535,11 +566,11 @@ func (q *syncQueue) trySync(req *syncRequest) { } } - logger.Debug("failed to sync with peer", "peer", req.to, "error", err) + logger.Trace("failed to sync with peer", "peer", req.to, "error", err) q.updatePeerScore(req.to, -1) } - logger.Debug("trying peers in prioritized order...") + logger.Trace("trying peers in prioritized order...") syncPeers := q.getSortedPeers() for _, peer := range syncPeers { @@ -556,19 +587,24 @@ func (q *syncQueue) trySync(req *syncRequest) { } err = q.pushResponse(resp, peer.pid) - if err != nil { - logger.Trace("failed to push block response", "error", err) + if err != nil && err != errEmptyResponseData && err != errEmptyJustificationData { + logger.Debug("failed to push block response", "error", err) } else { return } } - logger.Debug("failed to sync with any peer :(") - if req.req.StartingBlock.IsUint64() { + logger.Trace("failed to sync with any peer :(") + if req.req.StartingBlock.IsUint64() && (req.req.RequestedData&RequestedDataHeader) == 1 { q.requestData.Store(req.req.StartingBlock.Uint64(), requestData{ sent: true, received: false, }) + } else if req.req.StartingBlock.IsHash() && (req.req.RequestedData&RequestedDataHeader) == 0 { + q.justificationRequestData.Store(req.req.StartingBlock.Hash(), requestData{ + sent: true, + received: false, + }) } req.to = "" @@ -613,6 +649,12 @@ func (q *syncQueue) processBlockResponses() { for { select { case data := <-q.responseCh: + // if the response doesn't contain a header, then it's a justification-only response + if !data[0].Header.Exists() { + q.handleBlockJustification(data) + continue + } + q.handleBlockData(data) case <-q.ctx.Done(): return @@ -620,6 +662,32 @@ func (q *syncQueue) processBlockResponses() { } } +func (q *syncQueue) handleBlockJustification(data []*types.BlockData) { + startHash, endHash := data[0].Hash, data[len(data)-1].Hash + logger.Debug("sending justification data to syncer", "start", startHash, "end", endHash) + + _, err := q.s.syncer.ProcessBlockData(data) + if err != nil { + logger.Warn("failed to handle block justifications", "error", err) + return + } + + logger.Debug("finished processing justification data", "start", startHash, "end", endHash) + + // update peer's score + var from peer.ID + + d, ok := q.justificationRequestData.Load(startHash) + if !ok { + // this shouldn't happen + logger.Debug("can't find request data for response!", "start", startHash) + } else { + from = d.(requestData).from + q.updatePeerScore(from, 2) + q.justificationRequestData.Delete(startHash) + } +} + func (q *syncQueue) handleBlockData(data []*types.BlockData) { bestNum, err := q.s.blockState.BestBlockNumber() if err != nil { diff --git a/dot/network/sync_justification.go b/dot/network/sync_justification.go new file mode 100644 index 0000000000..c5310cc144 --- /dev/null +++ b/dot/network/sync_justification.go @@ -0,0 +1,89 @@ +// Copyright 2019 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . + +package network + +import ( + "math/big" + "time" +) + +func (q *syncQueue) finalizeAtHead() { + prev, err := q.s.blockState.GetFinalizedHeader(0, 0) + if err != nil { + logger.Error("failed to get latest finalized block header", "error", err) + return + } + + for { + select { + // sleep for average block time TODO: make this configurable from slot duration + case <-time.After(q.slotDuration * 2): + case <-q.ctx.Done(): + return + } + + curr, err := q.s.blockState.GetFinalizedHeader(0, 0) + if err != nil { + continue + } + + logger.Debug("checking finalized blocks", "curr", curr.Number, "prev", prev.Number) + + if curr.Number.Cmp(prev.Number) > 0 { + prev = curr + continue + } + + prev = curr + + // no new blocks have been finalized, request block justifications from peers + head, err := q.s.blockState.BestBlockNumber() + if err != nil { + continue + } + + start := head.Uint64() - uint64(blockRequestSize) + if curr.Number.Uint64() > start { + start = curr.Number.Uint64() + 1 + } else if int(start) < int(blockRequestSize) { + start = 1 + } + + q.pushJustificationRequest(start) + } +} + +func (q *syncQueue) pushJustificationRequest(start uint64) { + startHash, err := q.s.blockState.GetHashByNumber(big.NewInt(int64(start))) + if err != nil { + logger.Error("failed to get hash for block w/ number", "number", start, "error", err) + return + } + + req := createBlockRequestWithHash(startHash, blockRequestSize) + req.RequestedData = RequestedDataJustification + + logger.Debug("pushing justification request to queue", "start", start, "hash", startHash) + q.justificationRequestData.Store(startHash, requestData{ + received: false, + }) + + q.requestCh <- &syncRequest{ + req: req, + to: "", + } +} diff --git a/dot/network/sync_justification_test.go b/dot/network/sync_justification_test.go new file mode 100644 index 0000000000..00d854471a --- /dev/null +++ b/dot/network/sync_justification_test.go @@ -0,0 +1,168 @@ +// Copyright 2019 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . + +package network + +import ( + "context" + "math/big" + "testing" + "time" + + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/common/optional" + "github.com/ChainSafe/gossamer/lib/utils" + + "github.com/libp2p/go-libp2p-core/peer" + "github.com/stretchr/testify/require" +) + +func TestSyncQueue_PushResponse_Justification(t *testing.T) { + basePath := utils.NewTestBasePath(t, "nodeA") + config := &Config{ + BasePath: basePath, + Port: 7001, + RandSeed: 1, + NoBootstrap: true, + NoMDNS: true, + } + + s := createTestService(t, config) + s.syncQueue.stop() + time.Sleep(time.Second) + + peerID := peer.ID("noot") + msg := &BlockResponseMessage{ + BlockData: []*types.BlockData{}, + } + + for i := 0; i < int(blockRequestSize); i++ { + msg.BlockData = append(msg.BlockData, &types.BlockData{ + Hash: common.Hash{byte(i)}, + Justification: optional.NewBytes(true, []byte{1}), + }) + } + + s.syncQueue.justificationRequestData.Store(common.Hash{byte(0)}, requestData{}) + err := s.syncQueue.pushResponse(msg, peerID) + require.NoError(t, err) + require.Equal(t, 1, len(s.syncQueue.responseCh)) + data, ok := s.syncQueue.justificationRequestData.Load(common.Hash{byte(0)}) + require.True(t, ok) + require.Equal(t, requestData{ + sent: true, + received: true, + from: peerID, + }, data) +} + +func TestSyncQueue_PushResponse_EmptyJustification(t *testing.T) { + basePath := utils.NewTestBasePath(t, "nodeA") + config := &Config{ + BasePath: basePath, + Port: 7001, + RandSeed: 1, + NoBootstrap: true, + NoMDNS: true, + } + + s := createTestService(t, config) + s.syncQueue.stop() + time.Sleep(time.Second) + + peerID := peer.ID("noot") + msg := &BlockResponseMessage{ + BlockData: []*types.BlockData{}, + } + + for i := 0; i < int(blockRequestSize); i++ { + msg.BlockData = append(msg.BlockData, &types.BlockData{ + Hash: common.Hash{byte(i)}, + Justification: optional.NewBytes(false, nil), + }) + } + + s.syncQueue.justificationRequestData.Store(common.Hash{byte(0)}, &requestData{}) + err := s.syncQueue.pushResponse(msg, peerID) + require.Equal(t, errEmptyJustificationData, err) +} + +func TestSyncQueue_processBlockResponses_Justification(t *testing.T) { + q := newTestSyncQueue(t) + q.stop() + time.Sleep(time.Second) + q.ctx = context.Background() + + go func() { + q.responseCh <- []*types.BlockData{ + { + Hash: common.Hash{byte(0)}, + Header: optional.NewHeader(false, nil), + Body: optional.NewBody(false, nil), + Receipt: optional.NewBytes(false, nil), + MessageQueue: optional.NewBytes(false, nil), + Justification: optional.NewBytes(true, []byte{1}), + }, + } + }() + + peerID := peer.ID("noot") + q.justificationRequestData.Store(common.Hash{byte(0)}, requestData{ + from: peerID, + }) + + go q.processBlockResponses() + time.Sleep(time.Second) + + _, has := q.justificationRequestData.Load(common.Hash{byte(0)}) + require.False(t, has) + + score, ok := q.peerScore.Load(peerID) + require.True(t, ok) + require.Equal(t, 2, score) +} + +func TestSyncQueue_finalizeAtHead(t *testing.T) { + q := newTestSyncQueue(t) + q.stop() + time.Sleep(time.Second) + q.ctx = context.Background() + q.slotDuration = time.Millisecond * 200 + + hash, err := q.s.blockState.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + + go q.finalizeAtHead() + time.Sleep(time.Second) + + data, has := q.justificationRequestData.Load(hash) + require.True(t, has) + require.Equal(t, requestData{}, data) + + expected := createBlockRequestWithHash(hash, blockRequestSize) + expected.RequestedData = RequestedDataJustification + + select { + case req := <-q.requestCh: + require.Equal(t, &syncRequest{ + req: expected, + to: "", + }, req) + case <-time.After(time.Second): + t.Fatal("did not receive request") + } +} diff --git a/dot/network/transaction.go b/dot/network/transaction.go index 5eb1f7fa97..5131bc23ef 100644 --- a/dot/network/transaction.go +++ b/dot/network/transaction.go @@ -101,9 +101,7 @@ func (tm *TransactionMessage) IsHandshake() bool { return false } -type transactionHandshake struct { - Roles byte -} +type transactionHandshake struct{} // SubProtocol returns the transactions sub-protocol func (hs *transactionHandshake) SubProtocol() string { @@ -112,22 +110,16 @@ func (hs *transactionHandshake) SubProtocol() string { // String formats a transactionHandshake as a string func (hs *transactionHandshake) String() string { - return fmt.Sprintf("transactionHandshake Roles=%d", - hs.Roles) + return "transactionHandshake" } // Encode encodes a transactionHandshake message using SCALE func (hs *transactionHandshake) Encode() ([]byte, error) { - return scale.Encode(hs) + return []byte{}, nil } // Decode the message into a transactionHandshake func (hs *transactionHandshake) Decode(in []byte) error { - msg, err := scale.Decode(in, hs) - if err != nil { - return err - } - hs.Roles = msg.(*transactionHandshake).Roles return nil } @@ -147,19 +139,11 @@ func (hs *transactionHandshake) IsHandshake() bool { } func (s *Service) getTransactionHandshake() (Handshake, error) { - return &transactionHandshake{ - Roles: s.cfg.Roles, - }, nil + return &transactionHandshake{}, nil } func decodeTransactionHandshake(in []byte) (Handshake, error) { - if len(in) < 1 { - return nil, errors.New("invalid handshake") - } - - return &transactionHandshake{ - Roles: in[0], - }, nil + return &transactionHandshake{}, nil } func validateTransactionHandshake(_ peer.ID, _ Handshake) error { diff --git a/dot/network/transaction_test.go b/dot/network/transaction_test.go index db47f69e4b..07876168d9 100644 --- a/dot/network/transaction_test.go +++ b/dot/network/transaction_test.go @@ -29,9 +29,7 @@ import ( ) func TestDecodeTransactionHandshake(t *testing.T) { - testHandshake := &transactionHandshake{ - Roles: 4, - } + testHandshake := &transactionHandshake{} enc, err := testHandshake.Encode() require.NoError(t, err) diff --git a/dot/network/utils.go b/dot/network/utils.go index 49d3a9c054..62e9c53c00 100644 --- a/dot/network/utils.go +++ b/dot/network/utils.go @@ -185,11 +185,11 @@ func readStream(stream libp2pnetwork.Stream, buf []byte) (int, error) { if err == io.EOF { return 0, err } else if err != nil { - return 0, err // TODO: read bytes read from readLEB128ToUint64 + return 0, err // TODO: return bytes read from readLEB128ToUint64 } if length == 0 { - return 0, err // TODO: read bytes read from readLEB128ToUint64 + return 0, err // TODO: return bytes read from readLEB128ToUint64 } // TODO: check if length > len(buf), if so probably log.Crit diff --git a/dot/rpc/modules/system_test.go b/dot/rpc/modules/system_test.go index 08e3512faf..c1f982530c 100644 --- a/dot/rpc/modules/system_test.go +++ b/dot/rpc/modules/system_test.go @@ -87,6 +87,10 @@ func (s *mockBlockState) GetFinalizedHeader(_, _ uint64) (*types.Header, error) return s.BestBlockHeader() } +func (s *mockBlockState) GetHashByNumber(_ *big.Int) (common.Hash, error) { + return common.Hash{}, nil +} + type mockTransactionHandler struct{} func (h *mockTransactionHandler) HandleTransactionMessage(_ *network.TransactionMessage) error { diff --git a/dot/state/block.go b/dot/state/block.go index 7c2f33acd1..302d92fe5e 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -264,6 +264,16 @@ func (bs *BlockState) GetHeader(hash common.Hash) (*types.Header, error) { return result, err } +// GetHashByNumber returns the block hash given the number +func (bs *BlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { + bh, err := bs.db.Get(headerHashKey(num.Uint64())) + if err != nil { + return common.Hash{}, fmt.Errorf("cannot get block %d: %s", num, err) + } + + return common.NewHash(bh), nil +} + // GetHeaderByNumber returns a block header given a number func (bs *BlockState) GetHeaderByNumber(num *big.Int) (*types.Header, error) { bh, err := bs.db.Get(headerHashKey(num.Uint64())) diff --git a/dot/state/block_test.go b/dot/state/block_test.go index 844e092d33..0a54db54a4 100644 --- a/dot/state/block_test.go +++ b/dot/state/block_test.go @@ -365,3 +365,29 @@ func TestFinalization_DeleteBlock(t *testing.T) { // } } } + +func TestGetHashByNumber(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + + res, err := bs.GetHashByNumber(big.NewInt(0)) + require.NoError(t, err) + require.Equal(t, bs.genesisHash, res) + + header := &types.Header{ + Number: big.NewInt(1), + Digest: types.Digest{}, + ParentHash: testGenesisHeader.Hash(), + } + + block := &types.Block{ + Header: header, + Body: &types.Body{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + } + + err = bs.AddBlock(block) + require.NoError(t, err) + + res, err = bs.GetHashByNumber(big.NewInt(1)) + require.NoError(t, err) + require.Equal(t, header.Hash(), res) +} diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index c5a96dd23c..b04aadb864 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -187,6 +187,11 @@ func (s *Service) ProcessBlockData(data []*types.BlockData) (int, error) { logger.Debug("failed to add block to blocktree", "hash", bd.Hash, "error", err) } + if bd.Justification != nil && bd.Justification.Exists() { + logger.Debug("handling Justification...", "number", header.Number, "hash", bd.Hash) + s.handleJustification(header, bd.Justification.Value()) + } + continue } @@ -251,7 +256,7 @@ func (s *Service) ProcessBlockData(data []*types.BlockData) (int, error) { logger.Debug("block processed", "hash", bd.Hash) } - if bd.Justification != nil && bd.Justification.Exists() { + if bd.Justification != nil && bd.Justification.Exists() && header != nil { logger.Debug("handling Justification...", "number", bd.Number(), "hash", bd.Hash) s.handleJustification(header, bd.Justification.Value()) } diff --git a/lib/grandpa/network.go b/lib/grandpa/network.go index 5b1b444638..f596282b5f 100644 --- a/lib/grandpa/network.go +++ b/lib/grandpa/network.go @@ -104,7 +104,7 @@ func (s *Service) registerProtocol() error { func (s *Service) getHandshake() (Handshake, error) { return &GrandpaHandshake{ - Roles: 0, // TODO: are roles returned? + Roles: 1, // TODO: don't hard-code this }, nil }