Skip to content

Commit

Permalink
Add nil checks for protowire (#1570)
Browse files Browse the repository at this point in the history
* Handle errors in p2p handshake better

* Add a new errNil in protowire

* Add nil checks for all protowire.toAppMessage() functions for the p2p

* Add nil checks for all protowire.toAppMessage() functions for the RPC

* Add nil check for protwire KaspadMessage

Co-authored-by: Svarog <feanorr@gmail.com>
  • Loading branch information
elichai and svarogg authored Mar 2, 2021
1 parent 1548ed9 commit 7829a9f
Show file tree
Hide file tree
Showing 67 changed files with 1,882 additions and 604 deletions.
14 changes: 13 additions & 1 deletion app/protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,20 @@ func (m *Manager) routerInitializer(router *routerpkg.Router, netConnection *net

peer, err := handshake.HandleHandshake(m.context, netConnection, receiveVersionRoute,
sendVersionRoute, router.OutgoingRoute())

if err != nil {
m.handleError(err, netConnection, router.OutgoingRoute())
// non-blocking read from channel
select {
case innerError := <-errChan:
if errors.Is(err, routerpkg.ErrRouteClosed) {
m.handleError(innerError, netConnection, router.OutgoingRoute())
} else {
log.Errorf("Peer %s sent invalid message: %s", netConnection, innerError)
m.handleError(err, netConnection, router.OutgoingRoute())
}
default:
m.handleError(err, netConnection, router.OutgoingRoute())
}
return
}
defer m.context.RemoveFromPeers(peer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ import (
"github.com/pkg/errors"
)

var errorNil = errors.New("a required field is nil")

func (x *Hash) toDomain() (*externalapi.DomainHash, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "Hash is nil")
}
return externalapi.NewDomainHashFromByteSlice(x.Bytes)
}

Expand Down Expand Up @@ -43,6 +48,9 @@ func domainHashesToProto(hashes []*externalapi.DomainHash) []*Hash {
}

func (x *TransactionId) toDomain() (*externalapi.DomainTransactionID, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "TransactionId is nil")
}
return transactionid.FromBytes(x.Bytes)
}

Expand Down Expand Up @@ -74,7 +82,7 @@ func wireTransactionIDsToProto(ids []*externalapi.DomainTransactionID) []*Transa

func (x *SubnetworkId) toDomain() (*externalapi.DomainSubnetworkID, error) {
if x == nil {
return nil, nil
return nil, errors.Wrap(errorNil, "SubnetworkId is nil")
}
return subnetworks.FromBytes(x.Bytes)
}
Expand All @@ -89,6 +97,9 @@ func domainSubnetworkIDToProto(id *externalapi.DomainSubnetworkID) *SubnetworkId
}

func (x *NetAddress) toAppMessage() (*appmessage.NetAddress, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "NetAddress is nil")
}
if x.Port > math.MaxUint16 {
return nil, errors.Errorf("port number is larger than %d", math.MaxUint16)
}
Expand All @@ -108,3 +119,46 @@ func appMessageNetAddressToProto(address *appmessage.NetAddress) *NetAddress {
Port: uint32(address.Port),
}
}

func (x *Outpoint) toAppMessage() (*appmessage.Outpoint, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "Outpoint is nil")
}
transactionID, err := x.TransactionId.toDomain()
if err != nil {
return nil, err
}
return &appmessage.Outpoint{
TxID: *transactionID,
Index: x.Index,
}, nil
}

func (x *UtxoEntry) toAppMessage() (*appmessage.UTXOEntry, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "UtxoEntry is nil")
}
scriptPublicKey, err := x.ScriptPublicKey.toAppMessage()
if err != nil {
return nil, err
}
return &appmessage.UTXOEntry{
Amount: x.Amount,
ScriptPublicKey: scriptPublicKey,
BlockBlueScore: x.BlockBlueScore,
IsCoinbase: x.IsCoinbase,
}, nil
}

func (x *ScriptPublicKey) toAppMessage() (*externalapi.ScriptPublicKey, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "ScriptPublicKey is nil")
}
if x.Version > math.MaxUint16 {
return nil, errors.Errorf("ScriptPublicKey version is bigger then uint16.")
}
return &externalapi.ScriptPublicKey{
Script: x.Script,
Version: uint16(x.Version),
}, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,36 @@ import (
)

func (x *KaspadMessage_Addresses) toAppMessage() (appmessage.Message, error) {
protoAddresses := x.Addresses
if len(x.Addresses.AddressList) > appmessage.MaxAddressesPerMsg {
return nil, errors.Errorf("too many addresses for message "+
"[count %d, max %d]", len(x.Addresses.AddressList), appmessage.MaxAddressesPerMsg)
if x == nil {
return nil, errors.Wrap(errorNil, "KaspadMessage_Addresses is nil")
}
addressList, err := x.Addresses.toAppMessage()
if err != nil {
return nil, err
}
return &appmessage.MsgAddresses{
AddressList: addressList,
}, nil
}

addressList := make([]*appmessage.NetAddress, len(protoAddresses.AddressList))
for i, address := range protoAddresses.AddressList {
func (x *AddressesMessage) toAppMessage() ([]*appmessage.NetAddress, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "AddressesMessage is nil")
}

if len(x.AddressList) > appmessage.MaxAddressesPerMsg {
return nil, errors.Errorf("too many addresses for message "+
"[count %d, max %d]", len(x.AddressList), appmessage.MaxAddressesPerMsg)
}
addressList := make([]*appmessage.NetAddress, len(x.AddressList))
for i, address := range x.AddressList {
var err error
addressList[i], err = address.toAppMessage()
if err != nil {
return nil, err
}
}
return &appmessage.MsgAddresses{
AddressList: addressList,
}, nil
return addressList, nil
}

func (x *KaspadMessage_Addresses) fromAppMessage(msgAddresses *appmessage.MsgAddresses) error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
)

func (x *KaspadMessage_Block) toAppMessage() (appmessage.Message, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "KaspadMessage_Block is nil")
}
return x.Block.toAppMessage()
}

Expand All @@ -15,21 +18,14 @@ func (x *KaspadMessage_Block) fromAppMessage(msgBlock *appmessage.MsgBlock) erro
}

func (x *BlockMessage) toAppMessage() (*appmessage.MsgBlock, error) {
if x == nil {
return nil, errors.Wrap(errorNil, "BlockMessage is nil")
}
if len(x.Transactions) > appmessage.MaxTxPerBlock {
return nil, errors.Errorf("too many transactions to fit into a block "+
"[count %d, max %d]", len(x.Transactions), appmessage.MaxTxPerBlock)
}

protoBlockHeader := x.Header
if protoBlockHeader == nil {
return nil, errors.New("block header field cannot be nil")
}

if len(protoBlockHeader.ParentHashes) > appmessage.MaxBlockParents {
return nil, errors.Errorf("block header has %d parents, but the maximum allowed amount "+
"is %d", len(protoBlockHeader.ParentHashes), appmessage.MaxBlockParents)
}

header, err := x.Header.toAppMessage()
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
package protowire

import "github.com/kaspanet/kaspad/app/appmessage"
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)

func (x *KaspadMessage_BlockHeaders) toAppMessage() (appmessage.Message, error) {
blockHeaders := make([]*appmessage.MsgBlockHeader, len(x.BlockHeaders.BlockHeaders))
for i, blockHeader := range x.BlockHeaders.BlockHeaders {
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_BlockHeaders is nil")
}
blockHeaders, err := x.BlockHeaders.toAppMessage()
if err != nil {
return nil, err
}
return &appmessage.BlockHeadersMessage{
BlockHeaders: blockHeaders,
}, nil
}

func (x *BlockHeadersMessage) toAppMessage() ([]*appmessage.MsgBlockHeader, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "BlockHeadersMessage is nil")
}
blockHeaders := make([]*appmessage.MsgBlockHeader, len(x.BlockHeaders))
for i, blockHeader := range x.BlockHeaders {
var err error
blockHeaders[i], err = blockHeader.toAppMessage()
if err != nil {
return nil, err
}
}

return &appmessage.BlockHeadersMessage{
BlockHeaders: blockHeaders,
}, nil
return blockHeaders, nil
}

func (x *KaspadMessage_BlockHeaders) fromAppMessage(blockHeadersMessage *appmessage.BlockHeadersMessage) error {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,32 @@ package protowire

import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/kaspanet/kaspad/domain/consensus/model/externalapi"
"github.com/pkg/errors"
)

func (x *KaspadMessage_BlockLocator) toAppMessage() (appmessage.Message, error) {
if len(x.BlockLocator.Hashes) > appmessage.MaxBlockLocatorsPerMsg {
return nil, errors.Errorf("too many block locator hashes for message "+
"[count %d, max %d]", len(x.BlockLocator.Hashes), appmessage.MaxBlockLocatorsPerMsg)
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_BlockLocator is nil")
}
hashes, err := protoHashesToDomain(x.BlockLocator.Hashes)
hashes, err := x.BlockLocator.toAppMessage()
if err != nil {
return nil, err
}
return &appmessage.MsgBlockLocator{BlockLocatorHashes: hashes}, nil
}

func (x *BlockLocatorMessage) toAppMessage() ([]*externalapi.DomainHash, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "BlockLocatorMessage is nil")
}
if len(x.Hashes) > appmessage.MaxBlockLocatorsPerMsg {
return nil, errors.Errorf("too many block locator hashes for message "+
"[count %d, max %d]", len(x.Hashes), appmessage.MaxBlockLocatorsPerMsg)
}
return protoHashesToDomain(x.Hashes)
}

func (x *KaspadMessage_BlockLocator) fromAppMessage(msgBlockLocator *appmessage.MsgBlockLocator) error {
if len(msgBlockLocator.BlockLocatorHashes) > appmessage.MaxBlockLocatorsPerMsg {
return errors.Errorf("too many block locator hashes for message "+
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package protowire

import "github.com/kaspanet/kaspad/app/appmessage"
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)

func (x *KaspadMessage_DoneHeaders) toAppMessage() (appmessage.Message, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_DoneHeaders is nil")
}
return &appmessage.MsgDoneHeaders{}, nil
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package protowire

import "github.com/kaspanet/kaspad/app/appmessage"
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)

func (x *KaspadMessage_DonePruningPointUtxoSetChunks) toAppMessage() (appmessage.Message, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_DonePruningPointUtxoSetChunks is nil")
}
return &appmessage.MsgDonePruningPointUTXOSetChunks{}, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
)

func (x *BlockHeaderMessage) toAppMessage() (*appmessage.MsgBlockHeader, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "BlockHeaderMessage is nil")
}
if len(x.ParentHashes) > appmessage.MaxBlockParents {
return nil, errors.Errorf("block header has %d parents, but the maximum allowed amount "+
"is %d", len(x.ParentHashes), appmessage.MaxBlockParents)
Expand All @@ -17,17 +20,14 @@ func (x *BlockHeaderMessage) toAppMessage() (*appmessage.MsgBlockHeader, error)
if err != nil {
return nil, err
}

hashMerkleRoot, err := x.HashMerkleRoot.toDomain()
if err != nil {
return nil, err
}

acceptedIDMerkleRoot, err := x.AcceptedIdMerkleRoot.toDomain()
if err != nil {
return nil, err
}

utxoCommitment, err := x.UtxoCommitment.toDomain()
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package protowire

import "github.com/kaspanet/kaspad/app/appmessage"
import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)

func (x *KaspadMessage_IbdBlock) toAppMessage() (appmessage.Message, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_IbdBlock is nil")
}
msgBlock, err := x.IbdBlock.toAppMessage()
if err != nil {
return nil, err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,25 @@ package protowire

import (
"github.com/kaspanet/kaspad/app/appmessage"
"github.com/pkg/errors"
)

func (x *KaspadMessage_IbdBlockLocator) toAppMessage() (appmessage.Message, error) {
targetHash, err := x.IbdBlockLocator.TargetHash.toDomain()
if x == nil {
return nil, errors.Wrapf(errorNil, "KaspadMessage_IbdBlockLocator is nil")
}
return x.IbdBlockLocator.toAppMessage()
}

func (x *IbdBlockLocatorMessage) toAppMessage() (appmessage.Message, error) {
if x == nil {
return nil, errors.Wrapf(errorNil, "IbdBlockLocatorMessage is nil")
}
targetHash, err := x.TargetHash.toDomain()
if err != nil {
return nil, err
}
blockLocatorHash, err := protoHashesToDomain(x.IbdBlockLocator.BlockLocatorHashes)
blockLocatorHash, err := protoHashesToDomain(x.BlockLocatorHashes)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 7829a9f

Please sign in to comment.